"...graphsage/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "02d3197407487646dc5ca6abd889b8fe5fed1aef"
Unverified Commit 0b386a1d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

Merge Graphbolt to master (#5680)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-25-242.ap-northeast-1.compute.internal>
parent 41baa0e4
...@@ -19,6 +19,7 @@ from . import ( ...@@ -19,6 +19,7 @@ from . import (
container, container,
cuda, cuda,
dataloading, dataloading,
dataloading2,
distributed, distributed,
function, function,
ops, ops,
...@@ -47,6 +48,7 @@ from .dataloading import ( ...@@ -47,6 +48,7 @@ from .dataloading import (
set_node_lazy_features, set_node_lazy_features,
set_src_lazy_features, set_src_lazy_features,
) )
from .dataloading2 import *
from .heterograph import ( # pylint: disable=reimported from .heterograph import ( # pylint: disable=reimported
DGLGraph, DGLGraph,
DGLGraph as DGLHeteroGraph, DGLGraph as DGLHeteroGraph,
......
"""GraphBolt"""
from .itemset import *
"""Graph Bolt data fetcher base class"""
"""Graph Bolt DataLoaders"""
"""GraphBolt Itemset."""
__all__ = ["ItemSet", "DictItemSet"]
class ItemSet:
r"""An iterable itemset.
All itemsets that represent an iterable of items should subclass it. Such
form of itemset is particularly useful when items come from a stream. This
class requires each input itemset to be iterable.
Parameters
----------
items: Iterable or Tuple[Iterable]
"""
def __init__(self, items):
if isinstance(items, tuple):
self._items = items
else:
self._items = (items,)
def __iter__(self):
if len(self._items) == 1:
yield from self._items[0]
return
zip_items = zip(*self._items)
for item in zip_items:
yield tuple(item)
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class DictItemSet:
r"""Itemset wrapping multiple itemsets with keys.
Each item is retrieved by iterating over each itemset and returned with
corresponding key as a dict.
Parameters
----------
itemsets: Dict[str, ItemSet]
"""
def __init__(self, itemsets):
self._itemsets = itemsets
def __iter__(self):
for key, itemset in self._itemsets.items():
for item in itemset:
yield {key: item}
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
"""Graph Bolt minibatch sampler base class"""
"""Graph Bolt negative sampler base class"""
"""Graph Bolt subgraph sampler base class"""
import dgl
import torch
from torch.testing import assert_close
def test_ItemSet_node_edge_ids():
# Node or edge IDs.
item_set = dgl.ItemSet(torch.arange(0, 5))
for i, item in enumerate(item_set):
assert i == item.item()
def test_ItemSet_graphs():
# Graphs.
graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
item_set = dgl.ItemSet(graphs)
for i, item in enumerate(item_set):
assert graphs[i] == item
def test_ItemSet_node_pairs():
# Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = dgl.ItemSet(node_pairs)
for i, (src, dst) in enumerate(item_set):
assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst
def test_ItemSet_node_pairs_labels():
# Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
item_set = dgl.ItemSet((node_pairs[0], node_pairs[1], labels))
for i, (src, dst, label) in enumerate(item_set):
assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst
assert labels[i] == label
def test_ItemSet_head_tail_neg_tails():
# Head, tail and negative tails.
heads = torch.arange(0, 5)
tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = dgl.ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head
assert tails[i] == tail
assert_close(neg_tails[i], negs)
def test_DictItemSet_node_edge_ids():
# Node or edge IDs
ids = {
("user", "like", "item"): dgl.ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): dgl.ItemSet(torch.arange(0, 5)),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = dgl.DictItemSet(ids)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
assert chained_ids[i][0] in item
assert item[chained_ids[i][0]] == chained_ids[i][1]
def test_DictItemSet_node_pairs():
# Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = {
("user", "like", "item"): dgl.ItemSet(node_pairs),
("user", "follow", "user"): dgl.ItemSet(node_pairs),
}
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
assert expected_data[i][0] in item
assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_node_pairs_labels():
# Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
node_pairs_dict = {
("user", "like", "item"): dgl.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
("user", "follow", "user"): dgl.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
}
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
assert expected_data[i][0] in item
assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_head_tail_neg_tails():
# Head, tail and negative tails.
heads = torch.arange(0, 5)
tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = dgl.ItemSet((heads, tails, neg_tails))
data_dict = {
("user", "like", "item"): dgl.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): dgl.ItemSet((heads, tails, neg_tails)),
}
expected_data = []
for key, value in data_dict.items():
expected_data += [(key, v) for v in value]
item_set = dgl.DictItemSet(data_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
assert expected_data[i][0] in item
assert_close(item[expected_data[i][0]], expected_data[i][1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment