Unverified Commit 81ac9d27 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add MiniBatchBase (#6531)

parent c08f77bf
......@@ -4,6 +4,8 @@ from typing import Dict
from torch.utils.data import functional_datapipe
from .base import etype_tuple_to_str
from .minibatch_transformer import MiniBatchTransformer
......@@ -67,21 +69,22 @@ class FeatureFetcher(MiniBatchTransformer):
MiniBatch
An instance of :class:`MiniBatch` filled with required features.
"""
data.node_features = {}
num_layer = len(data.sampled_subgraphs) if data.sampled_subgraphs else 0
data.edge_features = [{} for _ in range(num_layer)]
node_features = {}
num_layers = data.num_layers()
edge_features = [{} for _ in range(num_layers)]
is_heterogeneous = isinstance(
self.node_feature_keys, Dict
) or isinstance(self.edge_feature_keys, Dict)
# Read Node features.
if self.node_feature_keys and data.input_nodes is not None:
input_nodes = data.node_ids()
if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = data.input_nodes[type_name]
nodes = input_nodes[type_name]
if nodes is None:
continue
for feature_name in feature_names:
data.node_features[
node_features[
(type_name, feature_name)
] = self.feature_store.read(
"node",
......@@ -91,39 +94,49 @@ class FeatureFetcher(MiniBatchTransformer):
)
else:
for feature_name in self.node_feature_keys:
data.node_features[feature_name] = self.feature_store.read(
node_features[feature_name] = self.feature_store.read(
"node",
None,
feature_name,
data.input_nodes,
input_nodes,
)
# Read Edge features.
if self.edge_feature_keys and data.sampled_subgraphs:
for i, subgraph in enumerate(data.sampled_subgraphs):
if subgraph.original_edge_ids is None:
if self.edge_feature_keys and num_layers > 0:
for i in range(num_layers):
original_edge_ids = data.edge_ids(i)
if original_edge_ids is None:
continue
if is_heterogeneous:
# Convert edge type to string for DGLMiniBatch.
original_edge_ids = {
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key: value
for key, value in original_edge_ids.items()
}
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = subgraph.original_edge_ids.get(type_name, None)
edges = original_edge_ids.get(type_name, None)
if edges is None:
continue
for feature_name in feature_names:
data.edge_features[i][
edge_features[i][
(type_name, feature_name)
] = self.feature_store.read(
"edge", type_name, feature_name, edges
)
else:
for feature_name in self.edge_feature_keys:
data.edge_features[i][
edge_features[i][
feature_name
] = self.feature_store.read(
"edge",
None,
feature_name,
subgraph.original_edge_ids,
original_edge_ids,
)
data.set_node_features(node_features)
data.set_edge_features(edge_features)
return data
"""Unified data structure for input and ouput of all the stages in loading process."""
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
......@@ -16,7 +16,53 @@ __all__ = ["DGLMiniBatch", "MiniBatch"]
@dataclass
class DGLMiniBatch:
class MiniBatchBase:
    """Base class for `MiniBatch` and `DGLMiniBatch`.

    Declares the accessors and mutators that both mini-batch flavours
    expose. Every method here raises ``NotImplementedError``; subclasses
    provide the actual behavior.
    """

    def node_ids(
        self,
    ) -> Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Return the input nodes of the outermost layer, i.e. all nodes in
        the mini-batch.

        - A tensor indicates a homogeneous graph.
        - A dictionary maps node type to the corresponding heterogeneous
          node ids.
        - ``None`` is a possible result; callers are expected to check for
          it before reading features.
        """
        raise NotImplementedError

    def num_layers(self) -> int:
        """Return the number of layers."""
        raise NotImplementedError

    def set_node_features(
        self,
        node_features: Union[
            Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
        ],
    ) -> None:
        """Set node features.

        Keys are either a feature name or a ``(type_name, feature_name)``
        tuple.
        """
        raise NotImplementedError

    def set_edge_features(
        self,
        edge_features: List[
            Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
        ],
    ) -> None:
        """Set edge features, one dictionary per layer."""
        raise NotImplementedError

    def edge_ids(
        self, layer_id: int
    ) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
        """Get the edge ids of layer `layer_id`.

        Implementations may return ``None`` when the layer carries no edge
        ids.
        """
        raise NotImplementedError

    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
        """Copy MiniBatch to the specified device."""
        raise NotImplementedError
@dataclass
class DGLMiniBatch(MiniBatchBase):
r"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library."""
......@@ -99,6 +145,47 @@ class DGLMiniBatch:
def __repr__(self) -> str:
return _dgl_minibatch_str(self)
def node_ids(
    self,
) -> Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Return `input_nodes`: the input nodes of the outermost layer, which
    contain all nodes in the `blocks`.

    - If `input_nodes` is a tensor: It indicates the graph is homogeneous.
    - If `input_nodes` is a dictionary: The keys should be node type and the
      value should be corresponding heterogeneous node id.
    - The attribute is returned as-is, so the result is ``None`` when
      `input_nodes` is ``None``.
    """
    return self.input_nodes
def num_layers(self) -> int:
    """Count the layers of this mini-batch: the length of `blocks`, or 0
    when `blocks` is ``None``."""
    blocks = self.blocks
    return 0 if blocks is None else len(blocks)
def edge_ids(
    self, layer_id: int
) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
    """Look up the edge ids stored under `dgl.EID` in the edge data of
    block `layer_id`; return ``None`` when that entry is absent."""
    edata = self.blocks[layer_id].edata
    return edata[dgl.EID] if dgl.EID in edata else None
def set_node_features(
    self,
    node_features: Union[
        Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
    ],
) -> None:
    """Store `node_features` on this mini-batch, overwriting the current
    `node_features` attribute."""
    self.node_features = node_features
def set_edge_features(
    self,
    edge_features: List[
        Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
    ],
) -> None:
    """Store `edge_features` on this mini-batch, overwriting the current
    `edge_features` attribute."""
    self.edge_features = edge_features
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection."""
......@@ -236,6 +323,45 @@ class MiniBatch:
def __repr__(self) -> str:
return _minibatch_str(self)
def node_ids(
    self,
) -> Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Return `input_nodes`: the input nodes of the outermost layer, which
    contain all nodes in the `sampled_subgraphs`.

    - If `input_nodes` is a tensor: It indicates the graph is homogeneous.
    - If `input_nodes` is a dictionary: The keys should be node type and the
      value should be corresponding heterogeneous node id.
    - The attribute is returned as-is, so the result is ``None`` when
      `input_nodes` is ``None``.
    """
    return self.input_nodes
def num_layers(self) -> int:
    """Count the layers of this mini-batch: the length of
    `sampled_subgraphs`, or 0 when it is ``None``."""
    subgraphs = self.sampled_subgraphs
    return 0 if subgraphs is None else len(subgraphs)
def edge_ids(
    self, layer_id: int
) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
    """Get the edge ids of a layer.

    Returns the `original_edge_ids` of `sampled_subgraphs[layer_id]`
    unchanged, so the result is ``None`` when the subgraph has no original
    edge ids recorded.
    """
    return self.sampled_subgraphs[layer_id].original_edge_ids
def set_node_features(
    self,
    node_features: Union[
        Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
    ],
) -> None:
    """Store `node_features` on this mini-batch, overwriting the current
    `node_features` attribute."""
    self.node_features = node_features
def set_edge_features(
    self,
    edge_features: List[
        Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
    ],
) -> None:
    """Store `edge_features` on this mini-batch, overwriting the current
    `edge_features` attribute."""
    self.edge_features = edge_features
def _to_dgl_blocks(self):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
......
......@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .minibatch import MiniBatch
from .minibatch import DGLMiniBatch, MiniBatch
__all__ = [
"MiniBatchTransformer",
......@@ -37,7 +37,7 @@ class MiniBatchTransformer(Mapper):
def _transformer(self, minibatch):
    """Apply `self.transformer` to `minibatch` and return the result.

    The result must be a `MiniBatch` or a `DGLMiniBatch`; anything else
    trips the assertion below.
    """
    minibatch = self.transformer(minibatch)
    # Both mini-batch flavours are accepted, so the message names both.
    assert isinstance(minibatch, (MiniBatch, DGLMiniBatch)), (
        "The transformer output should be an instance of MiniBatch or "
        "DGLMiniBatch"
    )
    return minibatch
......
import random
from enum import Enum
import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch
from torchdata.datapipes.iter import Mapper
def test_FeatureFetcher_invoke():
class MiniBatchType(Enum):
    """Mini-batch flavours used to parametrize the FeatureFetcher tests."""

    # Plain `MiniBatch` as produced by the sampler datapipe.
    MiniBatch = 1
    # `DGLMiniBatch`; the tests obtain it by appending `.to_dgl()`.
    DGLMiniBatch = 2
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_invoke(minibatch_type):
# Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
......@@ -29,6 +39,9 @@ def test_FeatureFetcher_invoke():
# Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
datapipe = datapipe.to_dgl()
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5
......@@ -39,7 +52,10 @@ def test_FeatureFetcher_invoke():
assert len(list(datapipe)) == 5
def test_FeatureFetcher_homo():
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_homo(minibatch_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
......@@ -59,12 +75,17 @@ def test_FeatureFetcher_homo():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
def test_FeatureFetcher_with_edges_homo():
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_homo(minibatch_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
......@@ -76,9 +97,12 @@ def test_FeatureFetcher_with_edges_homo():
def add_node_and_edge_ids(seeds):
subgraphs = []
for _ in range(3):
range_tensor = torch.arange(10)
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
node_pairs=(range_tensor, range_tensor),
original_column_node_ids=range_tensor,
original_row_node_ids=range_tensor,
original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,)
),
......@@ -96,6 +120,8 @@ def test_FeatureFetcher_with_edges_homo():
itemset = gb.ItemSet(torch.arange(10))
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
......@@ -128,7 +154,10 @@ def get_hetero_graph():
)
def test_FeatureFetcher_hetero():
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_hetero(minibatch_type):
graph = get_hetero_graph()
a = torch.tensor([[random.randint(0, 10)] for _ in range(2)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(3)])
......@@ -149,6 +178,8 @@ def test_FeatureFetcher_hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
)
......@@ -156,7 +187,10 @@ def test_FeatureFetcher_hetero():
assert len(list(fetcher_dp)) == 3
def test_FeatureFetcher_with_edges_hetero():
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_hetero(minibatch_type):
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])
......@@ -166,10 +200,29 @@ def test_FeatureFetcher_with_edges_hetero():
"n1:e1:n2": torch.randint(0, 50, (10,)),
"n2:e2:n1": torch.randint(0, 50, (10,)),
}
original_column_node_ids = {
"n1": torch.randint(0, 20, (10,)),
"n2": torch.randint(0, 20, (10,)),
}
original_row_node_ids = {
"n1": torch.randint(0, 20, (10,)),
"n2": torch.randint(0, 20, (10,)),
}
for _ in range(3):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
node_pairs={
"n1:e1:n2": (
torch.arange(10),
torch.arange(10),
),
"n2:e2:n1": (
torch.arange(10),
torch.arange(10),
),
},
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)
)
......@@ -189,6 +242,8 @@ def test_FeatureFetcher_with_edges_hetero():
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment