Unverified Commit 219c9f1a authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Refator feature fetcher (#6245)

parent 155608d3
......@@ -23,18 +23,25 @@ class DataBlock:
representing a subset of a larger graph structure.
"""
node_feature: Dict[Tuple[str, str], torch.Tensor] = None
node_features: Union[
Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
] = None
"""A representation of node features.
Keys are tuples of '(node_type, feature_name)' and the values are
corresponding features. Note that for a homogeneous graph, where there are
no node types, 'node_type' should be None.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(node_type, feature_name)'.
"""
edge_feature: List[Dict[Tuple[str, str], torch.Tensor]] = None
edge_features: List[
Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
] = None
"""Edge features associated with the 'sampled_subgraphs'.
The keys are tuples in the format '(edge_type, feature_name)', and the
values represent the corresponding features. In the case of a homogeneous
graph where no edge types exist, 'edge_type' should be set to None.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(edge_type, feature_name)'. Note, edge type is single
string of format 'str:str:str'.
"""
input_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
......
"""Feature fetchers"""
from typing import Dict
from torchdata.datapipes.iter import Mapper
class FeatureFetcher(Mapper):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
def __init__(self, datapipe, feature_store, feature_keys):
def __init__(
self,
datapipe,
feature_store,
node_feature_keys=None,
edge_feature_keys=None,
):
"""
Initlization for a feature fetcher.
......@@ -16,13 +24,24 @@ class FeatureFetcher(Mapper):
The datapipe.
feature_store : FeatureStore
A storage for features, support read and update.
feature_keys : (str, str, str)
Features need to be read, with each feature being uniquely identified
by a triplet '(domain, type_name, feature_name)'.
node_feature_keys : List[str] or Dict[str, List[str]]
Node features keys indicates the node features need to be read.
- If `node_features` is a list: It means the graph is homogeneous
graph, and the 'str' inside are feature names.
- If `node_features` is a dictionary: The keys should be node type
and the values are lists of feature names.
edge_feature_keys : List[str] or Dict[str, List[str]]
Edge features name indicates the edge features need to be read.
- If `edge_features` is a list: It means the graph is homogeneous
graph, and the 'str' inside are feature names.
- If `edge_features` is a dictionary: The keys are edge types,
following the format 'str:str:str', and the values are lists of
feature names.
"""
super().__init__(datapipe, self._read)
self.feature_store = feature_store
self.feature_keys = feature_keys
self.node_feature_keys = node_feature_keys
self.edge_feature_keys = edge_feature_keys
def _read(self, data):
"""
......@@ -40,38 +59,63 @@ class FeatureFetcher(Mapper):
DataBlock
An instance of 'DataBlock' filled with required features.
"""
data.node_feature = {}
data.node_features = {}
num_layer = len(data.sampled_subgraphs) if data.sampled_subgraphs else 0
data.edge_feature = [{} for _ in range(num_layer)]
for key in self.feature_keys:
domain, type_name, feature_name = key
if domain == "node" and data.input_nodes is not None:
nodes = (
data.input_nodes
if not type_name
else data.input_nodes[type_name]
)
if nodes is not None:
data.node_feature[
(type_name, feature_name)
] = self.feature_store.read(
domain,
type_name,
data.edge_features = [{} for _ in range(num_layer)]
is_heterogeneous = isinstance(
self.node_feature_keys, Dict
) or isinstance(self.edge_feature_keys, Dict)
# Read Node features.
if self.node_feature_keys and data.input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = data.input_nodes[type_name]
if nodes is None:
continue
for feature_name in feature_names:
data.node_features[
(type_name, feature_name)
] = self.feature_store.read(
"node",
type_name,
feature_name,
nodes,
)
else:
for feature_name in self.node_feature_keys:
data.node_features[feature_name] = self.feature_store.read(
"node",
None,
feature_name,
nodes,
data.input_nodes,
)
elif domain == "edge" and data.sampled_subgraphs is not None:
for i, subgraph in enumerate(data.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None:
edges = (
subgraph.reverse_edge_ids
if not type_name
else subgraph.reverse_edge_ids.get(type_name, None)
)
if edges is not None:
data.edge_feature[i][
# Read Edge features.
if self.edge_feature_keys and data.sampled_subgraphs:
for i, subgraph in enumerate(data.sampled_subgraphs):
if subgraph.reverse_edge_ids is None:
continue
if is_heterogeneous:
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = subgraph.reverse_edge_ids.get(type_name, None)
if edges is None:
continue
for feature_name in feature_names:
data.edge_features[i][
(type_name, feature_name)
] = self.feature_store.read(
domain, type_name, feature_name, edges
"edge", type_name, feature_name, edges
)
else:
for feature_name in self.edge_feature_keys:
data.edge_features[i][
feature_name
] = self.feature_store.read(
"edge",
None,
feature_name,
subgraph.reverse_edge_ids,
)
return data
......@@ -21,7 +21,7 @@ def test_FeatureFetcher_homo():
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
......@@ -54,14 +54,14 @@ def test_FeatureFetcher_with_edges_homo():
itemset = gb.ItemSet(torch.arange(10))
minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
for data in fetcher_dp:
assert data.node_feature[(None, "a")].size(0) == 2
assert len(data.edge_feature) == 3
for edge_feature in data.edge_feature:
assert edge_feature[(None, "b")].size(0) == 10
assert data.node_features["a"].size(0) == 2
assert len(data.edge_features) == 3
for edge_feature in data.edge_features:
assert edge_feature["b"].size(0) == 10
def get_hetero_graph():
......@@ -108,7 +108,9 @@ def test_FeatureFetcher_hetero():
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, keys)
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
)
assert len(list(fetcher_dp)) == 3
......@@ -148,11 +150,13 @@ def test_FeatureFetcher_with_edges_hetero():
)
minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, keys)
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
)
assert len(list(fetcher_dp)) == 5
for data in fetcher_dp:
assert data.node_feature[("n1", "a")].size(0) == 2
assert len(data.edge_feature) == 3
for edge_feature in data.edge_feature:
assert data.node_features[("n1", "a")].size(0) == 2
assert len(data.edge_features) == 3
for edge_feature in data.edge_features:
assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
......@@ -32,7 +32,7 @@ def test_DataLoader():
feature_fetcher = dgl.graphbolt.FeatureFetcher(
subgraph_sampler,
feature_store,
keys,
["a", "b"],
)
device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
......
......@@ -32,7 +32,9 @@ def test_DataLoader():
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
feature_fetcher = dgl.graphbolt.FeatureFetcher(
subgraph_sampler, feature_store, keys
subgraph_sampler,
feature_store,
["a"],
)
device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment