"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "308bd6f5b245929211f365396ca2007ac151b8e7"
Unverified Commit 97697055 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Grapbolt]Add unified data structure (#6140)

parent 7467da9a
...@@ -17,6 +17,9 @@ from .subgraph_sampler import * ...@@ -17,6 +17,9 @@ from .subgraph_sampler import *
from .sampled_subgraph import * from .sampled_subgraph import *
from .data_format import * from .data_format import *
from .negative_sampler import * from .negative_sampler import *
from .data_block import *
from .node_classification_block import *
from .link_prediction_block import *
from .utils import unique_and_compact_node_pairs from .utils import unique_and_compact_node_pairs
......
"""Unified data structure for input and ouput of all the stages in loading process."""
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import torch
from .sampled_subgraph import SampledSubgraph
__all__ = ["DataBlock"]
@dataclass
class DataBlock:
r"""A composite data class for data structure in the graphbolt. It is
designed to facilitate the exchange of data among different components
involved in processing data. The purpose of this class is to unify the
representation of input and output data across different stages, ensuring
consistency and ease of use throughout the loading process."""
sampled_subgraphs: List[SampledSubgraph] = None
"""
A list of 'SampledSubgraph's, each one corresponding to one layer,
representing a subset of a larger graph structure.
"""
node_feature: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""A representation of node feature.
- If `node_feature` is a tensor: It indicates the graph is homogeneous.
- If `node_feature` is a dictionary: The keys should be node type and the
value should be corresponding node feature or embedding.
"""
edge_feature: List[
Union[torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]]
] = None
"""A representation of edge feature corresponding to 'sampled_subgraphs'.
- If `edge_feature` is a tensor: It indicates the graph is homogeneous.
- If `edge_feature` is a dictionary: The keys should be edge type and the
value should be corresponding edge feature or embedding.
"""
input_nodes: Union[
torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]
] = None
"""A representation of input nodes in the outermost layer. Conatins all nodes
in the 'sampled_subgraphs'.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
"""Unified data structure for input and ouput of all the stages in loading
process, especially for edge level task."""
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
from .data_block import DataBlock
@dataclass
class LinkPredictionBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling edge level
tasks."""
node_pair: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pair` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pair` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
label: Union[torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]] = None
"""
Labels associated with the link prediction task.
- If `label` is a tensor: It indicates a homogeneous graph. The value are
edge labels corresponding to given 'node_pair'.
- If `label` is a dictionary: The keys should be edge type, and the value
should correspond to given 'node_pair'.
"""
negative_head: Union[
torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]
] = None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_head` is a tensor: It indicates a homogeneous graph.
- If `negative_head` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_tail: Union[
torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]
] = None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_tail` is a tensor: It indicates a homogeneous graph.
- If `negative_tail` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
compacted_node_pair: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of compacted node pairs corresponding to 'node_pair', where
all node ids inside are compacted.
"""
compacted_negative_head: Union[
torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]
] = None
"""
Representation of compacted nodes corresponding to 'negative_head', where
all node ids inside are compacted.
"""
compacted_negative_tail: Union[
torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor]
] = None
"""
Representation of compacted nodes corresponding to 'negative_tail', where
all node ids inside are compacted.
"""
"""Unified data structure for input and ouput of all the stages in loading
process, especially for node level task."""
from dataclasses import dataclass
from typing import Dict, Union
import torch
from .data_block import DataBlock
@dataclass
class NodeClassificationBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling node level
tasks."""
seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_node` is a tensor: It indicates the graph is homogeneous.
- If `seed_node` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous.
- If `label` is a dictionary: The keys should be node type and the
value should be corresponding node labels to given 'seed_node'.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment