Unverified Commit a0721405 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] don’t import dgl in the package. (#1382)



* fix dgl data.

* remove more.

* fix.

* fix.
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-150.us-west-2.compute.internal>
parent 10253a5c
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
from .utils import get_download_dir, download, extract_archive
from ..graph import DGLGraph
class BitcoinOTC(object):
......@@ -63,4 +63,4 @@ class BitcoinOTC(object):
@property
def is_temporal(self):
return True
\ No newline at end of file
return True
......@@ -10,10 +10,10 @@ import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import os, sys
from dgl import DGLGraph
import dgl
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
from ..graph import batch as graph_batch
_urls = {
'cora' : 'dataset/cora_raw.zip',
......@@ -314,13 +314,13 @@ class CoraBinary(object):
for line in f.readlines():
if line.startswith('graph'):
if len(elist) != 0:
self.graphs.append(dgl.DGLGraph(elist))
self.graphs.append(DGLGraph(elist))
elist = []
else:
u, v = line.strip().split(' ')
elist.append((int(u), int(v)))
if len(elist) != 0:
self.graphs.append(dgl.DGLGraph(elist))
self.graphs.append(DGLGraph(elist))
with open("{}/pmpds.pkl".format(root), 'rb') as f:
self.pmpds = _pickle_load(f)
self.labels = []
......@@ -348,7 +348,7 @@ class CoraBinary(object):
@staticmethod
def collate_fn(batch):
graphs, pmpds, labels = zip(*batch)
batched_graphs = dgl.batch(graphs)
batched_graphs = graph_batch(graphs)
batched_pmpds = sp.block_diag(pmpds)
batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels
......
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt
from ..graph import DGLGraph
class GDELT(object):
......
import scipy.sparse as sp
import numpy as np
from dgl import graph_index, DGLGraph, transform
import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
__all__=["AmazonCoBuy", "Coauthor", 'CoraFull']
......
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
import warnings
from .utils import get_download_dir, download, extract_archive, loadtxt
from ..graph import DGLGraph
class ICEWS18(object):
......@@ -88,4 +88,4 @@ class ICEWS18(object):
@property
def is_temporal(self):
return True
\ No newline at end of file
return True
......@@ -2,7 +2,7 @@
"""
import numpy as np
import networkx as nx
from dgl import DGLGraph
from ..graph import DGLGraph
class KarateClub(object):
......
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
from .utils import get_download_dir, download
from ..graph import DGLGraph
class QM7b(object):
"""
......
......@@ -2,9 +2,9 @@ from __future__ import absolute_import
import scipy.sparse as sp
import numpy as np
import dgl
import os, sys
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
class RedditDataset(object):
......@@ -19,7 +19,7 @@ class RedditDataset(object):
extract_archive(zip_file_path, extract_dir)
# graph
coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str)))
self.graph = dgl.DGLGraph(coo_adj, readonly=True)
self.graph = DGLGraph(coo_adj, readonly=True)
# features and labels
reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz"))
self.features = reddit_data["feature"]
......
......@@ -10,9 +10,10 @@ import networkx as nx
import numpy as np
import os
import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from .. import backend as F
from ..graph import DGLGraph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
__all__ = ['SSTBatch', 'SST']
......@@ -116,7 +117,7 @@ class SST(object):
# add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root)
ret = dgl.DGLGraph()
ret = DGLGraph()
ret.from_networkx(g, node_attrs=['x', 'y', 'mask'])
return ret
......
from __future__ import absolute_import
import numpy as np
import dgl
import os
import random
from dgl.data.utils import download, extract_archive, get_download_dir, loadtxt
from .utils import download, extract_archive, get_download_dir, loadtxt
from ..graph import DGLGraph
class LegacyTUDataset(object):
"""
......@@ -46,7 +45,7 @@ class LegacyTUDataset(object):
DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int))
g = dgl.DGLGraph()
g = DGLGraph()
g.add_nodes(int(DS_edge_list.max()) + 1)
g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
......@@ -180,7 +179,7 @@ class TUDataset(object):
DS_graph_labels = self._idx_from_zero(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
g = dgl.DGLGraph()
g = DGLGraph()
g.add_nodes(int(DS_edge_list.max()) + 1)
g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment