"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3cf42fd199259c061bb7749cb65757ba0d8a5b67"
Unverified commit 911c4aba, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] use str etype in sampler (#6232)

parent 0be8d8a7
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
from ...base import ETYPE from ...base import ETYPE
from ...convert import to_homogeneous from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import SampledSubgraphImpl
...@@ -505,11 +506,11 @@ class CSCSamplingGraph: ...@@ -505,11 +506,11 @@ class CSCSamplingGraph:
Parameters Parameters
---------- ----------
edge_type: Tuple[str] edge_type: str
The type of edges in the provided node_pairs. Any negative edges The type of edges in the provided node_pairs. Any negative edges
sampled will also have the same type. If set to None, it will be sampled will also have the same type. If set to None, it will be
considered as a homogeneous graph. considered as a homogeneous graph.
node_pairs : Tuple[Tensor] node_pairs : Tuple[Tensor, Tensor]
A tuple of two 1D tensors that represent the source and destination A tuple of two 1D tensors that represent the source and destination
of positive edges, with 'positive' indicating that these edges are of positive edges, with 'positive' indicating that these edges are
present in the graph. It's important to note that within the present in the graph. It's important to note that within the
...@@ -520,7 +521,7 @@ class CSCSamplingGraph: ...@@ -520,7 +521,7 @@ class CSCSamplingGraph:
Returns Returns
------- -------
Tuple[Tensor] Tuple[Tensor, Tensor]
A tuple consisting of two 1D tensors represents the source and A tuple consisting of two 1D tensors represents the source and
destination of negative edges. In the context of a heterogeneous destination of negative edges. In the context of a heterogeneous
graph, both the input nodes and the selected nodes are represented graph, both the input nodes and the selected nodes are represented
...@@ -528,12 +529,12 @@ class CSCSamplingGraph: ...@@ -528,12 +529,12 @@ class CSCSamplingGraph:
`edge_type`. Note that negative refers to false negatives, which `edge_type`. Note that negative refers to false negatives, which
means the edge could be present or not present in the graph. means the edge could be present or not present in the graph.
""" """
if edge_type: if edge_type is not None:
assert ( assert (
self.node_type_offset is not None self.node_type_offset is not None
), "The 'node_type_offset' array is necessary for performing \ ), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type." negative sampling by edge type."
_, _, dst_node_type = edge_type _, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type] dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
max_node_id = ( max_node_id = (
self.node_type_offset[dst_node_type_id + 1] self.node_type_offset[dst_node_type_id + 1]
......
...@@ -80,16 +80,16 @@ class NegativeSampler(Mapper): ...@@ -80,16 +80,16 @@ class NegativeSampler(Mapper):
Parameters Parameters
---------- ----------
node_pairs : Tuple[Tensor] node_pairs : Tuple[Tensor, Tensor]
A tuple of tensors or a dictionary represents source-destination A tuple of tensors that represent source-destination node pairs of
node pairs of positive edges, where positive means the edge must positive edges, where positive means the edge must exist in the
exist in the graph. graph.
etype : (str, str, str) etype : str
Canonical edge type. Canonical edge type.
Returns Returns
------- -------
Tuple[Tensor] Tuple[Tensor, Tensor]
A collection of negative node pairs. A collection of negative node pairs.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -102,14 +102,16 @@ class NegativeSampler(Mapper): ...@@ -102,14 +102,16 @@ class NegativeSampler(Mapper):
data : LinkPredictionBlock data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled The input data, which contains positive node pairs, will be filled
with negative information in this function. with negative information in this function.
neg_pairs : Tuple[Tensor] neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors represents source-destination node pairs of A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in negative edges, where negative means the edge may not exist in
the graph. the graph.
etype : (str, str, str) etype : str
Canonical edge type. Canonical edge type.
""" """
pos_src, pos_dst = data.node_pair[etype] if etype else data.node_pair pos_src, pos_dst = (
data.node_pair[etype] if etype is not None else data.node_pair
)
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src) pos_label = torch.ones_like(pos_src)
...@@ -117,7 +119,7 @@ class NegativeSampler(Mapper): ...@@ -117,7 +119,7 @@ class NegativeSampler(Mapper):
src = torch.cat([pos_src, neg_src]) src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst]) dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label]) label = torch.cat([pos_label, neg_label])
if etype: if etype is not None:
data.node_pair[etype] = (src, dst) data.node_pair[etype] = (src, dst)
data.label[etype] = label data.label[etype] = label
else: else:
...@@ -141,7 +143,7 @@ class NegativeSampler(Mapper): ...@@ -141,7 +143,7 @@ class NegativeSampler(Mapper):
raise TypeError( raise TypeError(
f"Unsupported output format {self.output_format}." f"Unsupported output format {self.output_format}."
) )
if etype: if etype is not None:
data.negative_head[etype] = neg_src data.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst data.negative_tail[etype] = neg_dst
else: else:
......
...@@ -5,6 +5,7 @@ from typing import Dict ...@@ -5,6 +5,7 @@ from typing import Dict
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple
from .data_block import LinkPredictionBlock, NodeClassificationBlock from .data_block import LinkPredictionBlock, NodeClassificationBlock
from .utils import unique_and_compact from .utils import unique_and_compact
...@@ -51,14 +52,17 @@ class SubgraphSampler(Mapper): ...@@ -51,14 +52,17 @@ class SubgraphSampler(Mapper):
if is_heterogeneous: if is_heterogeneous:
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = defaultdict(list) nodes = defaultdict(list)
for (src_type, _, dst_type), (src, dst) in node_pair.items(): for etype, (src, dst) in node_pair.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src) nodes[src_type].append(src)
nodes[dst_type].append(dst) nodes[dst_type].append(dst)
if has_neg_src: if has_neg_src:
for (src_type, _, _), src in neg_src.items(): for etype, src in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
nodes[src_type].append(src.view(-1)) nodes[src_type].append(src.view(-1))
if has_neg_dst: if has_neg_dst:
for (_, _, dst_type), dst in neg_dst.items(): for etype, dst in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
nodes[dst_type].append(dst.view(-1)) nodes[dst_type].append(dst.view(-1))
# Unique and compact the collected nodes. # Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes) seeds, compacted = unique_and_compact(nodes)
...@@ -69,16 +73,18 @@ class SubgraphSampler(Mapper): ...@@ -69,16 +73,18 @@ class SubgraphSampler(Mapper):
) = ({}, {}, {}) ) = ({}, {}, {})
# Map back in same order as collect. # Map back in same order as collect.
for etype, _ in node_pair.items(): for etype, _ in node_pair.items():
src_type, _, dst_type = etype src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0) src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0) dst = compacted[dst_type].pop(0)
compacted_node_pair[etype] = (src, dst) compacted_node_pair[etype] = (src, dst)
if has_neg_src: if has_neg_src:
for etype, _ in neg_src.items(): for etype, _ in neg_src.items():
compacted_negative_head[etype] = compacted[etype[0]].pop(0) src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_head[etype] = compacted[src_type].pop(0)
if has_neg_dst: if has_neg_dst:
for etype, _ in neg_dst.items(): for etype, _ in neg_dst.items():
compacted_negative_tail[etype] = compacted[etype[2]].pop(0) _, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_tail[etype] = compacted[dst_type].pop(0)
else: else:
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = list(node_pair) nodes = list(node_pair)
......
...@@ -184,13 +184,13 @@ def test_NegativeSampler_Hetero_Data(format): ...@@ -184,13 +184,13 @@ def test_NegativeSampler_Hetero_Data(format):
graph = get_hetero_graph() graph = get_hetero_graph()
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
("n1", "e1", "n2"): gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1]), torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]), torch.LongTensor([0, 2, 0, 1]),
) )
), ),
("n2", "e2", "n1"): gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1, 2, 2]), torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]), torch.LongTensor([0, 1, 1, 0, 0, 1]),
......
...@@ -104,13 +104,13 @@ def test_SubgraphSampler_Link_Hetero(labor): ...@@ -104,13 +104,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
graph = get_hetero_graph() graph = get_hetero_graph()
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
("n1", "e1", "n2"): gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1]), torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]), torch.LongTensor([0, 2, 0, 1]),
) )
), ),
("n2", "e2", "n1"): gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1, 2, 2]), torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]), torch.LongTensor([0, 1, 1, 0, 0, 1]),
...@@ -142,13 +142,13 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor): ...@@ -142,13 +142,13 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
graph = get_hetero_graph() graph = get_hetero_graph()
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
("n1", "e1", "n2"): gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1]), torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]), torch.LongTensor([0, 2, 0, 1]),
) )
), ),
("n2", "e2", "n1"): gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( (
torch.LongTensor([0, 0, 1, 1, 2, 2]), torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]), torch.LongTensor([0, 1, 1, 0, 0, 1]),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment