Unverified Commit d82cfdc0 authored by peizhou001, committed by GitHub

[Graphbolt] Add feature fetcher udf (#6210)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 5c1206de
@@ -19,25 +19,23 @@ class DataBlock:
     consistency and ease of use throughout the loading process."""
 
     sampled_subgraphs: List[SampledSubgraph] = None
-    """
-    A list of 'SampledSubgraph's, each one corresponding to one layer,
+    """A list of 'SampledSubgraph's, each one corresponding to one layer,
     representing a subset of a larger graph structure.
     """
 
-    node_feature: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
-    """A representation of node feature.
-    - If `node_feature` is a tensor: It indicates the graph is homogeneous.
-    - If `node_feature` is a dictionary: The keys should be node type and the
-      value should be corresponding node feature or embedding.
+    node_feature: Dict[Tuple[str, str], torch.Tensor] = None
+    """A representation of node features.
+    Keys are tuples of '(node_type, feature_name)' and the values are the
+    corresponding features. Note that for a homogeneous graph, where there are
+    no node types, 'node_type' should be None.
     """
 
-    edge_feature: List[
-        Union[torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]]
-    ] = None
-    """A representation of edge feature corresponding to 'sampled_subgraphs'.
-    - If `edge_feature` is a tensor: It indicates the graph is homogeneous.
-    - If `edge_feature` is a dictionary: The keys should be edge type and the
-      value should be corresponding edge feature or embedding.
+    edge_feature: List[Dict[Tuple[str, str], torch.Tensor]] = None
+    """Edge features associated with the 'sampled_subgraphs'.
+    The keys are tuples in the format '(edge_type, feature_name)', and the
+    values represent the corresponding features. In the case of a homogeneous
+    graph where no edge types exist, 'edge_type' should be set to None.
+    Note that 'edge_type' is of the format 'str:str:str'.
+    """
 
     input_nodes: Union[
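For illustration only (not part of this commit), the populated containers described above would look roughly like the following; the tensors, feature names, and sizes are made up.

import torch

# Hypothetical node features for a homogeneous graph: the type slot is None.
node_feature = {
    (None, "feat"): torch.randn(4, 16),  # features of the 4 input nodes
}
# One dict per sampled layer; in the heterogeneous case, 'edge_type' is the
# "src_type:edge_type:dst_type" string and 'node_type' is the node type name.
edge_feature = [
    {(None, "weight"): torch.randn(10)},  # layer 0
    {(None, "weight"): torch.randn(7)},   # layer 1
]
hetero_node_feature = {
    ("n1", "feat"): torch.randn(2, 16),
    ("n2", "feat"): torch.randn(3, 16),
}
hetero_edge_feature = [{("n1:e1:n2", "weight"): torch.randn(10)}]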
@@ -4,19 +4,77 @@ from torchdata.datapipes.iter import Mapper
 class FeatureFetcher(Mapper):
-    """Base feature fetcher.
-
-    This is equivalent to the following iterator:
-
-    .. code:: python
-
-        for data in datapipe:
-            yield feature_fetch_func(data)
-
-    Parameters
-    ----------
-    datapipe : DataPipe
-        The datapipe.
-    fn : callable
-        The function that performs feature fetching.
-    """
+    """A feature fetcher used to fetch features for node/edge in graphbolt."""
+
+    def __init__(self, datapipe, feature_store, feature_keys):
+        """
+        Initialization for a feature fetcher.
+
+        Parameters
+        ----------
+        datapipe : DataPipe
+            The datapipe.
+        feature_store : FeatureStore
+            A storage for features, supporting read and update.
+        feature_keys : List[Tuple[str, str, str]]
+            The features to be read, with each feature uniquely identified by
+            a triplet '(domain, type_name, feature_name)'.
+        """
+        super().__init__(datapipe, self._read)
+        self.feature_store = feature_store
+        self.feature_keys = feature_keys
+
+    def _read(self, data):
+        """
+        Fill in the node/edge features field in data.
+
+        Parameters
+        ----------
+        data : DataBlock
+            An instance of the 'DataBlock' class. Even if 'node_feature' or
+            'edge_feature' is already filled, it will be overwritten for
+            overlapping features.
+
+        Returns
+        -------
+        DataBlock
+            An instance of 'DataBlock' filled with the required features.
+        """
+        data.node_feature = {}
+        num_layer = len(data.sampled_subgraphs) if data.sampled_subgraphs else 0
+        data.edge_feature = [{} for _ in range(num_layer)]
+        for key in self.feature_keys:
+            domain, type_name, feature_name = key
+            if domain == "node" and data.input_nodes is not None:
+                nodes = (
+                    data.input_nodes
+                    if not type_name
+                    else data.input_nodes[type_name]
+                )
+                if nodes is not None:
+                    data.node_feature[
+                        (type_name, feature_name)
+                    ] = self.feature_store.read(
+                        domain,
+                        type_name,
+                        feature_name,
+                        nodes,
+                    )
+            elif domain == "edge" and data.sampled_subgraphs is not None:
+                for i, subgraph in enumerate(data.sampled_subgraphs):
+                    if subgraph.reverse_edge_ids is not None:
+                        edges = (
+                            subgraph.reverse_edge_ids
+                            if not type_name
+                            # TODO(#6211): Clean up the edge type converter.
+                            else subgraph.reverse_edge_ids.get(
+                                tuple(type_name.split(":")), None
+                            )
+                        )
+                        if edges is not None:
+                            data.edge_feature[i][
+                                (type_name, feature_name)
+                            ] = self.feature_store.read(
+                                domain, type_name, feature_name, edges
+                            )
+        return data
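To see how the new constructor and feature keys fit together, here is a minimal usage sketch (not part of the commit) condensed from the tests below; the feature names "feat"/"label", the sizes, and the direct construction of the block from the seed nodes (skipping a neighbor sampler) are illustrative assumptions.

import dgl.graphbolt as gb
import torch
from torchdata.datapipes.iter import Mapper

num_nodes = 200
# Feature store keyed by (domain, type_name, feature_name); type_name is None
# because the graph is homogeneous.
keys = [("node", None, "feat"), ("node", None, "label")]
features = {
    keys[0]: gb.TorchBasedFeature(torch.randn(num_nodes, 4)),
    keys[1]: gb.TorchBasedFeature(torch.randint(0, 10, (num_nodes,))),
}
feature_store = gb.BasicFeatureStore(features)

# Seed nodes -> minibatches -> DataBlocks with 'input_nodes' filled in.
itemset = gb.ItemSet(torch.arange(num_nodes))
minibatch_dp = gb.MinibatchSampler(itemset, batch_size=8)
block_dp = Mapper(
    minibatch_dp, lambda seeds: gb.NodeClassificationBlock(input_nodes=seeds)
)

# The fetcher fills data.node_feature keyed by (node_type, feature_name).
fetcher_dp = gb.FeatureFetcher(block_dp, feature_store, keys)
for data in fetcher_dp:
    feat = data.node_feature[(None, "feat")]    # expected shape (8, 4)
    label = data.node_feature[(None, "label")]  # expected shape (8,)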
@@ -8,6 +8,11 @@ import scipy.sparse as sp
 import torch
 
 
+def to_node_block(data):
+    block = gb.NodeClassificationBlock(seed_node=data)
+    return block
+
+
 def rand_csc_graph(N, density):
     adj = sp.random(N, N, density)
     adj = adj + adj.T
-import dgl
-import dgl.graphbolt
-import pytest
-import torch
-
-
-def get_graphbolt_fetch_func():
-    feature_store = {
-        "feature": dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4)),
-        "label": dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,))),
-    }
-
-    def fetch_func(data):
-        return feature_store["feature"].read(data), feature_store["label"].read(
-            data
-        )
-
-    return fetch_func
-
-
-def get_tensor_fetch_func():
-    feature_store = torch.randn(200, 4)
-    label = torch.randint(0, 10, (200,))
-
-    def fetch_func(data):
-        return feature_store[data], label[data]
-
-    return fetch_func
-
-
-@pytest.mark.parametrize(
-    "fetch_func", [get_graphbolt_fetch_func(), get_tensor_fetch_func()]
-)
-def test_FeatureFetcher(fetch_func):
-    itemset = dgl.graphbolt.ItemSet(torch.arange(10))
-    minibatch_dp = dgl.graphbolt.MinibatchSampler(itemset, batch_size=2)
-    fetcher_dp = dgl.graphbolt.FeatureFetcher(minibatch_dp, fetch_func)
-
-    assert len(list(fetcher_dp)) == 5

+import dgl.graphbolt as gb
+import gb_test_utils
+import torch
+from torchdata.datapipes.iter import Mapper
+
+
+def test_FeatureFetcher_homo():
+    graph = gb_test_utils.rand_csc_graph(20, 0.15)
+    a = torch.randint(0, 10, (graph.num_nodes,))
+    b = torch.randint(0, 10, (graph.num_edges,))
+
+    features = {}
+    keys = [("node", None, "a"), ("edge", None, "b")]
+    features[keys[0]] = gb.TorchBasedFeature(a)
+    features[keys[1]] = gb.TorchBasedFeature(b)
+    feature_store = gb.BasicFeatureStore(features)
+
+    itemset = gb.ItemSet(torch.arange(10))
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
+    keys.append(("node", None, "c"))
+    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
+
+    assert len(list(fetcher_dp)) == 5
+
+
+def test_FeatureFetcher_with_edges_homo():
+    graph = gb_test_utils.rand_csc_graph(20, 0.15)
+    a = torch.randint(0, 10, (graph.num_nodes,))
+    b = torch.randint(0, 10, (graph.num_edges,))
+
+    def add_node_and_edge_ids(seeds):
+        subgraphs = []
+        for _ in range(3):
+            subgraphs.append(
+                gb.SampledSubgraphImpl(
+                    node_pairs=(torch.tensor([]), torch.tensor([])),
+                    reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)),
+                )
+            )
+        data = gb.NodeClassificationBlock(
+            input_nodes=seeds, sampled_subgraphs=subgraphs
+        )
+        return data
+
+    features = {}
+    keys = [("node", None, "a"), ("edge", None, "b")]
+    features[keys[0]] = gb.TorchBasedFeature(a)
+    features[keys[1]] = gb.TorchBasedFeature(b)
+    feature_store = gb.BasicFeatureStore(features)
+
+    itemset = gb.ItemSet(torch.arange(10))
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
+
+    assert len(list(fetcher_dp)) == 5
+    for data in fetcher_dp:
+        assert data.node_feature[(None, "a")].size(0) == 2
+        assert len(data.edge_feature) == 3
+        for edge_feature in data.edge_feature:
+            assert edge_feature[(None, "b")].size(0) == 10
+
+
+def get_hetero_graph():
+    # COO graph:
+    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
+    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
+    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
+    # num_nodes = 5, num_n1 = 2, num_n2 = 3
+    ntypes = {"n1": 0, "n2": 1}
+    etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
+    metadata = gb.GraphMetadata(ntypes, etypes)
+    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
+    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
+    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
+    node_type_offset = torch.LongTensor([0, 2, 5])
+    return gb.from_csc(
+        indptr,
+        indices,
+        node_type_offset=node_type_offset,
+        type_per_edge=type_per_edge,
+        metadata=metadata,
+    )
+
+
+def test_FeatureFetcher_hetero():
+    graph = get_hetero_graph()
+    a = torch.randint(0, 10, (2,))
+    b = torch.randint(0, 10, (3,))
+
+    features = {}
+    keys = [("node", "n1", "a"), ("node", "n2", "a")]
+    features[keys[0]] = gb.TorchBasedFeature(a)
+    features[keys[1]] = gb.TorchBasedFeature(b)
+    feature_store = gb.BasicFeatureStore(features)
+
+    itemset = gb.ItemSetDict(
+        {
+            "n1": gb.ItemSet(torch.LongTensor([0, 1])),
+            "n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
+        }
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
+    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
+
+    assert len(list(fetcher_dp)) == 3
+
+
+def test_FeatureFetcher_with_edges_hetero():
+    a = torch.randint(0, 10, (20,))
+    b = torch.randint(0, 10, (50,))
+
+    def add_node_and_edge_ids(seeds):
+        subgraphs = []
+        reverse_edge_ids = {
+            ("n1", "e1", "n2"): torch.randint(0, 50, (10,)),
+            ("n2", "e2", "n1"): torch.randint(0, 50, (10,)),
+        }
+        for _ in range(3):
+            subgraphs.append(
+                gb.SampledSubgraphImpl(
+                    node_pairs=(torch.tensor([]), torch.tensor([])),
+                    reverse_edge_ids=reverse_edge_ids,
+                )
+            )
+        data = gb.NodeClassificationBlock(
+            input_nodes=seeds, sampled_subgraphs=subgraphs
+        )
+        return data
+
+    features = {}
+    keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
+    features[keys[0]] = gb.TorchBasedFeature(a)
+    features[keys[1]] = gb.TorchBasedFeature(b)
+    feature_store = gb.BasicFeatureStore(features)
+
+    itemset = gb.ItemSetDict(
+        {
+            "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
+        }
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
+
+    assert len(list(fetcher_dp)) == 5
+    for data in fetcher_dp:
+        assert data.node_feature[("n1", "a")].size(0) == 2
+        assert len(data.edge_feature) == 3
+        for edge_feature in data.edge_feature:
+            assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
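The TODO(#6211) in FeatureFetcher._read refers to the string-to-tuple edge type conversion exercised by the hetero test above; below is a small illustrative sketch (not part of the commit) of that mapping, using stand-in lists instead of tensors.

# Feature keys spell the edge type as "src_type:edge_type:dst_type", while
# 'reverse_edge_ids' in a sampled subgraph is keyed by the tuple form.
feature_key = ("edge", "n1:e1:n2", "a")
domain, type_name, feature_name = feature_key

reverse_edge_ids = {
    ("n1", "e1", "n2"): [3, 7, 9],  # stand-in for a tensor of edge IDs
    ("n2", "e2", "n1"): [0, 2, 5],
}

# tuple("n1:e1:n2".split(":")) == ("n1", "e1", "n2")
edges = reverse_edge_ids.get(tuple(type_name.split(":")), None)
assert edges == [3, 7, 9]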
@@ -11,23 +11,6 @@ import torch
 from torchdata.datapipes.iter import Mapper
 
 
-def to_node_block(data):
-    block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
-    return block
-
-
-def to_tuple(data):
-    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
-    return data.input_nodes, output_nodes, data.sampled_subgraphs
-
-
-def fetch_func(features, labels, data):
-    input_nodes, output_nodes, adjs = data
-    input_features = features.read(input_nodes)
-    output_labels = labels.read(output_nodes)
-    return input_features, output_labels, adjs
-
-
 @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
 # TODO (peizhou): Will enable windows test once CSCSamplingGraph is pickleable.
 def test_DataLoader():
@@ -35,20 +18,23 @@ def test_DataLoader():
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N))
     graph = gb_test_utils.rand_csc_graph(200, 0.15)
-    features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
-    labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
+    features = {}
+    keys = [("node", None, "a"), ("node", None, "b")]
+    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
+    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
+    feature_store = dgl.graphbolt.BasicFeatureStore(features)
 
     minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    block_converter = Mapper(minibatch_sampler, to_node_block)
+    block_converter = Mapper(minibatch_sampler, gb_test_utils.to_node_block)
     subgraph_sampler = dgl.graphbolt.NeighborSampler(
         block_converter,
         graph,
         fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
-    tuple_converter = Mapper(subgraph_sampler, to_tuple)
     feature_fetcher = dgl.graphbolt.FeatureFetcher(
-        tuple_converter,
-        partial(fetch_func, features, labels),
+        subgraph_sampler,
+        feature_store,
+        keys,
     )
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
@@ -12,24 +12,17 @@ def to_node_block(data):
     return block
 
 
-def to_tuple(data):
-    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
-    return data.input_nodes, output_nodes, data.sampled_subgraphs
-
-
 def test_DataLoader():
     N = 32
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N))
     graph = gb_test_utils.rand_csc_graph(200, 0.15)
-    features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
-    labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
-
-    def fetch_func(data):
-        input_nodes, output_nodes, adjs = data
-        input_features = features.read(input_nodes)
-        output_labels = labels.read(output_nodes)
-        return input_features, output_labels, adjs
+    features = {}
+    keys = [("node", None, "a"), ("node", None, "b")]
+    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
+    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
+    feature_store = dgl.graphbolt.BasicFeatureStore(features)
 
     minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
     block_converter = Mapper(minibatch_sampler, to_node_block)
@@ -38,8 +31,9 @@ def test_DataLoader():
         graph,
         fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
-    tuple_converter = Mapper(subgraph_sampler, to_tuple)
-    feature_fetcher = dgl.graphbolt.FeatureFetcher(tuple_converter, fetch_func)
+    feature_fetcher = dgl.graphbolt.FeatureFetcher(
+        subgraph_sampler, feature_store, keys
+    )
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
     dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer)
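For completeness, a hedged sketch (not part of the commit) of how the dataloader built in the test above would typically be consumed, assuming the DataBlock layout documented earlier in this diff; the training-step details are omitted.

def consume(dataloader):
    """Hypothetical consumption of the pipeline built in test_DataLoader."""
    for data in dataloader:
        # Node features are keyed by (node_type, feature_name); node_type is
        # None because the test graph is homogeneous.
        x = data.node_feature[(None, "a")]
        y = data.node_feature[(None, "b")]
        # One SampledSubgraph per sampled layer, to be fed to the model.
        subgraphs = data.sampled_subgraphs
        # ... model forward/backward on (x, y, subgraphs) would go here ...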