Unverified Commit ce66280d authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Reorganzie graphbolt module (#5711)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-25-242.ap-northeast-1.compute.internal>
parent 72d16f78
...@@ -19,7 +19,6 @@ from . import ( ...@@ -19,7 +19,6 @@ from . import (
container, container,
cuda, cuda,
dataloading, dataloading,
dataloading2,
distributed, distributed,
function, function,
ops, ops,
...@@ -48,7 +47,6 @@ from .dataloading import ( ...@@ -48,7 +47,6 @@ from .dataloading import (
set_node_lazy_features, set_node_lazy_features,
set_src_lazy_features, set_src_lazy_features,
) )
from .dataloading2 import *
from .heterograph import ( # pylint: disable=reimported from .heterograph import ( # pylint: disable=reimported
DGLGraph, DGLGraph,
DGLGraph as DGLHeteroGraph, DGLGraph as DGLHeteroGraph,
......
"""Graph Bolt data fetcher base class"""
"""Graph Bolt minibatch sampler base class"""
"""Graph Bolt negative sampler base class"""
"""Graph Bolt subgraph sampler base class"""
import dgl import dgl
import torch import torch
from torch.testing import assert_close from torch.testing import assert_close
from dgl.graphbolt import *
def test_ItemSet_node_edge_ids(): def test_ItemSet_node_edge_ids():
# Node or edge IDs. # Node or edge IDs.
item_set = dgl.ItemSet(torch.arange(0, 5)) item_set = ItemSet(torch.arange(0, 5))
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
...@@ -13,7 +14,7 @@ def test_ItemSet_node_edge_ids(): ...@@ -13,7 +14,7 @@ def test_ItemSet_node_edge_ids():
def test_ItemSet_graphs(): def test_ItemSet_graphs():
# Graphs. # Graphs.
graphs = [dgl.rand_graph(10, 20) for _ in range(5)] graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
item_set = dgl.ItemSet(graphs) item_set = ItemSet(graphs)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert graphs[i] == item assert graphs[i] == item
...@@ -21,7 +22,7 @@ def test_ItemSet_graphs(): ...@@ -21,7 +22,7 @@ def test_ItemSet_graphs():
def test_ItemSet_node_pairs(): def test_ItemSet_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = dgl.ItemSet(node_pairs) item_set = ItemSet(node_pairs)
for i, (src, dst) in enumerate(item_set): for i, (src, dst) in enumerate(item_set):
assert node_pairs[0][i] == src assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst assert node_pairs[1][i] == dst
...@@ -31,7 +32,7 @@ def test_ItemSet_node_pairs_labels(): ...@@ -31,7 +32,7 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = dgl.ItemSet((node_pairs[0], node_pairs[1], labels)) item_set = ItemSet((node_pairs[0], node_pairs[1], labels))
for i, (src, dst, label) in enumerate(item_set): for i, (src, dst, label) in enumerate(item_set):
assert node_pairs[0][i] == src assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst assert node_pairs[1][i] == dst
...@@ -43,7 +44,7 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -43,7 +44,7 @@ def test_ItemSet_head_tail_neg_tails():
heads = torch.arange(0, 5) heads = torch.arange(0, 5)
tails = torch.arange(5, 10) tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2) neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = dgl.ItemSet((heads, tails, neg_tails)) item_set = ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set): for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head assert heads[i] == head
assert tails[i] == tail assert tails[i] == tail
...@@ -53,13 +54,13 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -53,13 +54,13 @@ def test_ItemSet_head_tail_neg_tails():
def test_DictItemSet_node_edge_ids(): def test_DictItemSet_node_edge_ids():
# Node or edge IDs # Node or edge IDs
ids = { ids = {
("user", "like", "item"): dgl.ItemSet(torch.arange(0, 5)), ("user", "like", "item"): ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): dgl.ItemSet(torch.arange(0, 5)), ("user", "follow", "user"): ItemSet(torch.arange(0, 5)),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = dgl.DictItemSet(ids) item_set = DictItemSet(ids)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -71,13 +72,13 @@ def test_DictItemSet_node_pairs(): ...@@ -71,13 +72,13 @@ def test_DictItemSet_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): dgl.ItemSet(node_pairs), ("user", "like", "item"): ItemSet(node_pairs),
("user", "follow", "user"): dgl.ItemSet(node_pairs), ("user", "follow", "user"): ItemSet(node_pairs),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(node_pairs_dict) item_set = DictItemSet(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -90,17 +91,17 @@ def test_DictItemSet_node_pairs_labels(): ...@@ -90,17 +91,17 @@ def test_DictItemSet_node_pairs_labels():
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): dgl.ItemSet( ("user", "like", "item"): ItemSet(
(node_pairs[0], node_pairs[1], labels) (node_pairs[0], node_pairs[1], labels)
), ),
("user", "follow", "user"): dgl.ItemSet( ("user", "follow", "user"): ItemSet(
(node_pairs[0], node_pairs[1], labels) (node_pairs[0], node_pairs[1], labels)
), ),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(node_pairs_dict) item_set = DictItemSet(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -113,15 +114,15 @@ def test_DictItemSet_head_tail_neg_tails(): ...@@ -113,15 +114,15 @@ def test_DictItemSet_head_tail_neg_tails():
heads = torch.arange(0, 5) heads = torch.arange(0, 5)
tails = torch.arange(5, 10) tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2) neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = dgl.ItemSet((heads, tails, neg_tails)) item_set = ItemSet((heads, tails, neg_tails))
data_dict = { data_dict = {
("user", "like", "item"): dgl.ItemSet((heads, tails, neg_tails)), ("user", "like", "item"): ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): dgl.ItemSet((heads, tails, neg_tails)), ("user", "follow", "user"): ItemSet((heads, tails, neg_tails)),
} }
expected_data = [] expected_data = []
for key, value in data_dict.items(): for key, value in data_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(data_dict) item_set = DictItemSet(data_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment