Unverified Commit 911c4aba authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use str etype in sampler (#6232)

parent 0be8d8a7
......@@ -11,6 +11,7 @@ import torch
from ...base import ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple
from .sampled_subgraph_impl import SampledSubgraphImpl
......@@ -505,11 +506,11 @@ class CSCSamplingGraph:
Parameters
----------
edge_type: Tuple[str]
edge_type: str
The type of edges in the provided node_pairs. Any negative edges
sampled will also have the same type. If set to None, it will be
considered as a homogeneous graph.
node_pairs : Tuple[Tensor]
node_pairs : Tuple[Tensor, Tensor]
A tuple of two 1D tensors that represent the source and destination
of positive edges, with 'positive' indicating that these edges are
present in the graph. It's important to note that within the
......@@ -520,7 +521,7 @@ class CSCSamplingGraph:
Returns
-------
Tuple[Tensor]
Tuple[Tensor, Tensor]
A tuple consisting of two 1D tensors represents the source and
destination of negative edges. In the context of a heterogeneous
graph, both the input nodes and the selected nodes are represented
......@@ -528,12 +529,12 @@ class CSCSamplingGraph:
`edge_type`. Note that negative refers to false negatives, which
means the edge could be present or not present in the graph.
"""
if edge_type:
if edge_type is not None:
assert (
self.node_type_offset is not None
), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type."
_, _, dst_node_type = edge_type
_, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
max_node_id = (
self.node_type_offset[dst_node_type_id + 1]
......
......@@ -80,16 +80,16 @@ class NegativeSampler(Mapper):
Parameters
----------
node_pairs : Tuple[Tensor]
A tuple of tensors or a dictionary represents source-destination
node pairs of positive edges, where positive means the edge must
exist in the graph.
etype : (str, str, str)
node_pairs : Tuple[Tensor, Tensor]
A tuple of tensors that represent source-destination node pairs of
positive edges, where positive means the edge must exist in the
graph.
etype : str
Canonical edge type.
Returns
-------
Tuple[Tensor]
Tuple[Tensor, Tensor]
A collection of negative node pairs.
"""
raise NotImplementedError
......@@ -102,14 +102,16 @@ class NegativeSampler(Mapper):
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
with negative information in this function.
neg_pairs : Tuple[Tensor]
neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in
the graph.
etype : (str, str, str)
etype : str
Canonical edge type.
"""
pos_src, pos_dst = data.node_pair[etype] if etype else data.node_pair
pos_src, pos_dst = (
data.node_pair[etype] if etype is not None else data.node_pair
)
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src)
......@@ -117,7 +119,7 @@ class NegativeSampler(Mapper):
src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label])
if etype:
if etype is not None:
data.node_pair[etype] = (src, dst)
data.label[etype] = label
else:
......@@ -141,7 +143,7 @@ class NegativeSampler(Mapper):
raise TypeError(
f"Unsupported output format {self.output_format}."
)
if etype:
if etype is not None:
data.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst
else:
......
......@@ -5,6 +5,7 @@ from typing import Dict
from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple
from .data_block import LinkPredictionBlock, NodeClassificationBlock
from .utils import unique_and_compact
......@@ -51,14 +52,17 @@ class SubgraphSampler(Mapper):
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
for (src_type, _, dst_type), (src, dst) in node_pair.items():
for etype, (src, dst) in node_pair.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src)
nodes[dst_type].append(dst)
if has_neg_src:
for (src_type, _, _), src in neg_src.items():
for etype, src in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
nodes[src_type].append(src.view(-1))
if has_neg_dst:
for (_, _, dst_type), dst in neg_dst.items():
for etype, dst in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
nodes[dst_type].append(dst.view(-1))
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
......@@ -69,16 +73,18 @@ class SubgraphSampler(Mapper):
) = ({}, {}, {})
# Map back in same order as collect.
for etype, _ in node_pair.items():
src_type, _, dst_type = etype
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_node_pair[etype] = (src, dst)
if has_neg_src:
for etype, _ in neg_src.items():
compacted_negative_head[etype] = compacted[etype[0]].pop(0)
src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_head[etype] = compacted[src_type].pop(0)
if has_neg_dst:
for etype, _ in neg_dst.items():
compacted_negative_tail[etype] = compacted[etype[2]].pop(0)
_, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_tail[etype] = compacted[dst_type].pop(0)
else:
# Collect nodes from all types of input.
nodes = list(node_pair)
......
......@@ -184,13 +184,13 @@ def test_NegativeSampler_Hetero_Data(format):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
("n1", "e1", "n2"): gb.ItemSet(
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
),
("n2", "e2", "n1"): gb.ItemSet(
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
......
......@@ -104,13 +104,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
("n1", "e1", "n2"): gb.ItemSet(
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
),
("n2", "e2", "n1"): gb.ItemSet(
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
......@@ -142,13 +142,13 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
("n1", "e1", "n2"): gb.ItemSet(
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
),
("n2", "e2", "n1"): gb.ItemSet(
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment