Unverified Commit 219c9f1a authored by peizhou001, committed by GitHub

[Graphbolt] Refactor feature fetcher (#6245)

parent 155608d3
@@ -23,18 +23,25 @@ class DataBlock:
     representing a subset of a larger graph structure.
     """

-    node_feature: Dict[Tuple[str, str], torch.Tensor] = None
+    node_features: Union[
+        Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
+    ] = None
     """A representation of node features.
-    Keys are tuples of '(node_type, feature_name)' and the values are
-    corresponding features. Note that for a homogeneous graph, where there are
-    no node types, 'node_type' should be None.
+    - If the keys are single strings, the graph is homogeneous and the keys
+      are feature names.
+    - If the keys are tuples, the graph is heterogeneous and the keys are
+      '(node_type, feature_name)' tuples.
     """

-    edge_feature: List[Dict[Tuple[str, str], torch.Tensor]] = None
+    edge_features: List[
+        Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
+    ] = None
     """Edge features associated with the 'sampled_subgraphs'.
-    The keys are tuples in the format '(edge_type, feature_name)', and the
-    values represent the corresponding features. In the case of a homogeneous
-    graph where no edge types exist, 'edge_type' should be set to None.
+    - If the keys are single strings, the graph is homogeneous and the keys
+      are feature names.
+    - If the keys are tuples, the graph is heterogeneous and the keys are
+      '(edge_type, feature_name)' tuples. Note that an edge type is a single
+      string in the format 'str:str:str'.
     """

     input_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
...
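For orientation, here is a small illustration of how the new `node_features` / `edge_features` containers are keyed in the homogeneous and heterogeneous cases. This is not part of the change itself; the tensor shapes are made up.

```python
import torch

# Homogeneous graph: plain feature-name keys.
node_features = {"a": torch.randn(4, 16)}
edge_features = [{"b": torch.randn(10, 8)}]  # one dict per sampled layer

# Heterogeneous graph: (type, feature_name) tuple keys; an edge type is a
# single 'src:rel:dst' string such as "n1:e1:n2".
node_features_hetero = {("n1", "a"): torch.randn(4, 16)}
edge_features_hetero = [{("n1:e1:n2", "a"): torch.randn(10, 8)}]
```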
"""Feature fetchers""" """Feature fetchers"""
from typing import Dict
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
class FeatureFetcher(Mapper): class FeatureFetcher(Mapper):
"""A feature fetcher used to fetch features for node/edge in graphbolt.""" """A feature fetcher used to fetch features for node/edge in graphbolt."""
def __init__(self, datapipe, feature_store, feature_keys): def __init__(
self,
datapipe,
feature_store,
node_feature_keys=None,
edge_feature_keys=None,
):
""" """
Initlization for a feature fetcher. Initlization for a feature fetcher.
@@ -16,13 +24,24 @@ class FeatureFetcher(Mapper):
             The datapipe.
         feature_store : FeatureStore
             A storage for features, supporting read and update.
-        feature_keys : (str, str, str)
-            Features need to be read, with each feature being uniquely identified
-            by a triplet '(domain, type_name, feature_name)'.
+        node_feature_keys : List[str] or Dict[str, List[str]]
+            The node features to be read.
+            - If `node_feature_keys` is a list, the graph is homogeneous and
+              the strings inside are feature names.
+            - If `node_feature_keys` is a dictionary, the keys are node types
+              and the values are lists of feature names.
+        edge_feature_keys : List[str] or Dict[str, List[str]]
+            The edge features to be read.
+            - If `edge_feature_keys` is a list, the graph is homogeneous and
+              the strings inside are feature names.
+            - If `edge_feature_keys` is a dictionary, the keys are edge types,
+              following the format 'str:str:str', and the values are lists of
+              feature names.
         """
         super().__init__(datapipe, self._read)
         self.feature_store = feature_store
-        self.feature_keys = feature_keys
+        self.node_feature_keys = node_feature_keys
+        self.edge_feature_keys = edge_feature_keys

     def _read(self, data):
         """
@@ -40,38 +59,63 @@ class FeatureFetcher(Mapper):
         DataBlock
             An instance of 'DataBlock' filled with required features.
         """
-        data.node_feature = {}
+        data.node_features = {}
         num_layer = len(data.sampled_subgraphs) if data.sampled_subgraphs else 0
-        data.edge_feature = [{} for _ in range(num_layer)]
-        for key in self.feature_keys:
-            domain, type_name, feature_name = key
-            if domain == "node" and data.input_nodes is not None:
-                nodes = (
-                    data.input_nodes
-                    if not type_name
-                    else data.input_nodes[type_name]
-                )
-                if nodes is not None:
-                    data.node_feature[
-                        (type_name, feature_name)
-                    ] = self.feature_store.read(
-                        domain,
-                        type_name,
-                        feature_name,
-                        nodes,
-                    )
-            elif domain == "edge" and data.sampled_subgraphs is not None:
-                for i, subgraph in enumerate(data.sampled_subgraphs):
-                    if subgraph.reverse_edge_ids is not None:
-                        edges = (
-                            subgraph.reverse_edge_ids
-                            if not type_name
-                            else subgraph.reverse_edge_ids.get(type_name, None)
-                        )
-                        if edges is not None:
-                            data.edge_feature[i][
-                                (type_name, feature_name)
-                            ] = self.feature_store.read(
-                                domain, type_name, feature_name, edges
-                            )
+        data.edge_features = [{} for _ in range(num_layer)]
+        is_heterogeneous = isinstance(
+            self.node_feature_keys, Dict
+        ) or isinstance(self.edge_feature_keys, Dict)
+        # Read Node features.
+        if self.node_feature_keys and data.input_nodes is not None:
+            if is_heterogeneous:
+                for type_name, feature_names in self.node_feature_keys.items():
+                    nodes = data.input_nodes[type_name]
+                    if nodes is None:
+                        continue
+                    for feature_name in feature_names:
+                        data.node_features[
+                            (type_name, feature_name)
+                        ] = self.feature_store.read(
+                            "node",
+                            type_name,
+                            feature_name,
+                            nodes,
+                        )
+            else:
+                for feature_name in self.node_feature_keys:
+                    data.node_features[feature_name] = self.feature_store.read(
+                        "node",
+                        None,
+                        feature_name,
+                        data.input_nodes,
+                    )
+        # Read Edge features.
+        if self.edge_feature_keys and data.sampled_subgraphs:
+            for i, subgraph in enumerate(data.sampled_subgraphs):
+                if subgraph.reverse_edge_ids is None:
+                    continue
+                if is_heterogeneous:
+                    for (
+                        type_name,
+                        feature_names,
+                    ) in self.edge_feature_keys.items():
+                        edges = subgraph.reverse_edge_ids.get(type_name, None)
+                        if edges is None:
+                            continue
+                        for feature_name in feature_names:
+                            data.edge_features[i][
+                                (type_name, feature_name)
+                            ] = self.feature_store.read(
+                                "edge", type_name, feature_name, edges
+                            )
+                else:
+                    for feature_name in self.edge_feature_keys:
+                        data.edge_features[i][
+                            feature_name
+                        ] = self.feature_store.read(
+                            "edge",
+                            None,
+                            feature_name,
+                            subgraph.reverse_edge_ids,
+                        )
         return data
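The branching above is easier to see in isolation. Below is a self-contained sketch, not DGL code: `ToyFeatureStore` and `fetch_node_features` are made-up names that mirror the node-feature path of `_read()`. A dict of keys means heterogeneous lookups stored under `(type_name, feature_name)`, a list means homogeneous lookups stored under the feature name, and both go through the same `read(domain, type_name, feature_name, ids)` call shape.

```python
import torch


class ToyFeatureStore:
    """Minimal stand-in for a feature store with the read() shape used above:
    read(domain, type_name, feature_name, ids)."""

    def __init__(self, features):
        # features: {(domain, type_name, feature_name): tensor}
        self._features = features

    def read(self, domain, type_name, feature_name, ids):
        return self._features[(domain, type_name, feature_name)][ids]


def fetch_node_features(store, input_nodes, node_feature_keys):
    """Mirrors the node branch of _read(): dict keys => heterogeneous."""
    out = {}
    if isinstance(node_feature_keys, dict):
        for type_name, feature_names in node_feature_keys.items():
            nodes = input_nodes[type_name]
            for feature_name in feature_names:
                out[(type_name, feature_name)] = store.read(
                    "node", type_name, feature_name, nodes
                )
    else:
        for feature_name in node_feature_keys:
            out[feature_name] = store.read(
                "node", None, feature_name, input_nodes
            )
    return out


# Homogeneous: a flat list of feature names, a single node-ID tensor.
homo_store = ToyFeatureStore({("node", None, "a"): torch.arange(20.0).view(10, 2)})
print(fetch_node_features(homo_store, torch.tensor([0, 3, 7]), ["a"]))

# Heterogeneous: per-node-type feature names, per-type node-ID tensors.
hetero_store = ToyFeatureStore({("node", "n1", "a"): torch.arange(20.0).view(10, 2)})
print(fetch_node_features(hetero_store, {"n1": torch.tensor([1, 2])}, {"n1": ["a"]}))
```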
@@ -21,7 +21,7 @@ def test_FeatureFetcher_homo():
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
-    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
+    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])

     assert len(list(fetcher_dp)) == 5
@@ -54,14 +54,14 @@ def test_FeatureFetcher_with_edges_homo():
     itemset = gb.ItemSet(torch.arange(10))
     minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
     converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
-    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
+    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])

     assert len(list(fetcher_dp)) == 5
     for data in fetcher_dp:
-        assert data.node_feature[(None, "a")].size(0) == 2
-        assert len(data.edge_feature) == 3
-        for edge_feature in data.edge_feature:
-            assert edge_feature[(None, "b")].size(0) == 10
+        assert data.node_features["a"].size(0) == 2
+        assert len(data.edge_features) == 3
+        for edge_feature in data.edge_features:
+            assert edge_feature["b"].size(0) == 10


 def get_hetero_graph():
@@ -108,7 +108,9 @@ def test_FeatureFetcher_hetero():
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
-    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
+    fetcher_dp = gb.FeatureFetcher(
+        sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
+    )

     assert len(list(fetcher_dp)) == 3
@@ -148,11 +150,13 @@ def test_FeatureFetcher_with_edges_hetero():
     )
     minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
     converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
-    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
+    fetcher_dp = gb.FeatureFetcher(
+        converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
+    )

     assert len(list(fetcher_dp)) == 5
     for data in fetcher_dp:
-        assert data.node_feature[("n1", "a")].size(0) == 2
-        assert len(data.edge_feature) == 3
-        for edge_feature in data.edge_feature:
+        assert data.node_features[("n1", "a")].size(0) == 2
+        assert len(data.edge_features) == 3
+        for edge_feature in data.edge_features:
             assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
@@ -32,7 +32,7 @@ def test_DataLoader():
     feature_fetcher = dgl.graphbolt.FeatureFetcher(
         subgraph_sampler,
         feature_store,
-        keys,
+        ["a", "b"],
     )
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
...
@@ -32,7 +32,9 @@ def test_DataLoader():
         fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
     feature_fetcher = dgl.graphbolt.FeatureFetcher(
-        subgraph_sampler, feature_store, keys
+        subgraph_sampler,
+        feature_store,
+        ["a"],
     )
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
...
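Finally, for orientation only: FeatureFetcher subclasses torchdata's `Mapper`, so it slots into the same datapipe chain wired up in these DataLoader tests. The toy chain below uses stand-in lambdas and no DGL objects; it is an assumption-laden sketch showing where a fetcher-style Mapper sits between sampling and device transfer, not the real pipeline.

```python
from torchdata.datapipes.iter import IterableWrapper, Mapper

# Stand-in for MinibatchSampler: yields minibatches of node IDs.
minibatches = IterableWrapper([[0, 1], [2, 3], [4, 5]])
# Stand-in for a subgraph sampler: wraps IDs into a data-block-like dict.
sampled = Mapper(minibatches, lambda ids: {"input_nodes": ids})
# Stand-in for FeatureFetcher: attaches node features keyed by feature name.
fetched = Mapper(
    sampled,
    lambda data: {
        **data,
        "node_features": {"a": [float(i) for i in data["input_nodes"]]},
    },
)

for batch in fetched:
    print(batch)
```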