Unverified Commit 97ed294d authored by Ramon Zhou, committed by GitHub

[GraphBolt] Delete DGLMiniBatch (#6760)

parent b483c26f
"""Unified data structure for input and ouput of all the stages in loading process."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
import dgl
from dgl.heterograph import DGLBlock
from dgl.utils import recursive_apply
from .base import CSCFormatBase, etype_str_to_tuple
from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph
__all__ = ["DGLMiniBatch", "MiniBatch"]
@dataclass
class MiniBatchBase(object):
"""Base class for `MiniBatch` and `DGLMiniBatch`."""
def node_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the MiniBatch.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node ids.
"""
raise NotImplementedError
def num_layers(self) -> int:
"""Return the number of layers."""
raise NotImplementedError
def set_node_features(
self,
node_features: Union[
Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
],
) -> None:
"""Set node features."""
raise NotImplementedError
def set_edge_features(
self,
edge_features: List[
Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
],
) -> None:
"""Set edge features."""
raise NotImplementedError
def edge_ids(
self, layer_id: int
) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
"""Get the edge ids of a layer."""
raise NotImplementedError
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy MiniBatch to the specified device."""
raise NotImplementedError
@dataclass
class DGLMiniBatch(MiniBatchBase):
r"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library."""
blocks: List[DGLBlock] = None
"""A list of 'DGLBlock's, each one corresponding to one layer, representing
a bipartite graph used for message passing.
"""
input_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""A representation of input nodes in the outermost layer. Conatins all
nodes in the 'blocks'.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node ids.
"""
output_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""Representation of output nodes, usually also the seed nodes, used for
sampling in the graph.
- If `output_nodes` is a tensor: It indicates the graph is homogeneous.
- If `output_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node ids.
"""
node_features: Union[
Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
] = None
"""A representation of node features.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(node_type, feature_name)'.
"""
edge_features: List[
Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
] = None
"""Edge features associated with the 'blocks'.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(edge_type, feature_name)'. Note that an edge type is a
triplet of the format (str, str, str).
"""
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""Labels associated with seed nodes / node pairs in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The values
are the labels corresponding to the given 'output_nodes' or 'node_pairs'.
- If `labels` is a dictionary: The keys are node or edge types and the values
are the labels corresponding to the given 'output_nodes' or 'node_pairs'.
"""
positive_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""Representation of positive graphs used for evaluating or computing loss
in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
negative_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""Representation of negative graphs used for evaluating or computing loss in
link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
def __repr__(self) -> str:
return _dgl_minibatch_str(self)
def node_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `blocks`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node ids.
"""
return self.input_nodes
def num_layers(self) -> int:
"""Return the number of layers."""
if self.blocks is None:
return 0
return len(self.blocks)
def edge_ids(
self, layer_id: int
) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
"""Get edge ids of a layer."""
if dgl.EID not in self.blocks[layer_id].edata:
return None
return self.blocks[layer_id].edata[dgl.EID]
def set_node_features(
self,
node_features: Union[
Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
],
) -> None:
"""Set node features."""
self.node_features = node_features
def set_edge_features(
self,
edge_features: List[
Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
],
) -> None:
"""Set edge features."""
self.edge_features = edge_features
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection."""
def _to(x, device):
return x.to(device) if hasattr(x, "to") else x
for attr in dir(self):
# Only copy member variables.
if not callable(getattr(self, attr)) and not attr.startswith("__"):
setattr(
self,
attr,
recursive_apply(
getattr(self, attr), lambda x: _to(x, device)
),
)
return self
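# Usage sketch (illustrative, assuming `minibatch` is a populated
# DGLMiniBatch and a CUDA device is available): the reflection-based `to()`
# above applies `recursive_apply` to every non-callable, non-dunder
# attribute, so nested lists and dicts of tensors are moved as well.
#
#     minibatch = minibatch.to(torch.device("cuda"))
#     assert minibatch.input_nodes.device.type == "cuda"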
__all__ = ["MiniBatch"]
@dataclass
@@ -654,34 +462,3 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
final_str + f"{name}={_add_indent(val, len(name)+1)},\n" + " " * 10
)
return "MiniBatch(" + final_str[:-3] + ")"
def _dgl_minibatch_str(dglminibatch: DGLMiniBatch) -> str:
final_str = ""
# Get all attributes in the class except methods.
attributes = get_attributes(dglminibatch)
attributes.reverse()
# Insert key with its value into the string.
for name in attributes:
val = getattr(dglminibatch, name)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
# Let each element of a list or tuple occupy one line, and add extra
# indentation when the value's string representation spans multiple
# lines.
if isinstance(val, list):
val = [str(val_str) for val_str in val]
val = "[" + ",\n".join(val) + "]"
elif isinstance(val, tuple):
val = [str(val_str) for val_str in val]
val = "(" + ",\n".join(val) + ")"
else:
val = str(val)
final_str = (
final_str + f"{name}={_add_indent(val, len(name)+15)},\n" + " " * 13
)
return "DGLMiniBatch(" + final_str[:-3] + ")"
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .minibatch import DGLMiniBatch, MiniBatch
from .minibatch import MiniBatch
__all__ = [
"MiniBatchTransformer",
@@ -37,6 +37,6 @@ class MiniBatchTransformer(Mapper):
def _transformer(self, minibatch):
minibatch = self.transformer(minibatch)
assert isinstance(
minibatch, (MiniBatch, DGLMiniBatch)
minibatch, (MiniBatch,)
), "The transformer output should be an instance of MiniBatch"
return minibatch
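# Minimal usage sketch (assumptions: `datapipe` yields MiniBatch instances
# and `add_self_loops` is a user-defined transform returning the modified
# MiniBatch; neither name is part of this commit):
#
#     datapipe = MiniBatchTransformer(datapipe, add_self_loops)
#     for minibatch in datapipe:
#         ...  # every yielded item is asserted to be a MiniBatch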