"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8331da46837be40f96fbd24de6a6fb2da28acd11"
Unverified Commit 845864d2 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `SubgraphSampler` to support `seeds`. (#7049)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ee8b7b39
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict from typing import Dict
import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
...@@ -69,10 +70,16 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -69,10 +70,16 @@ class SubgraphSampler(MiniBatchTransformer):
seeds_timestamp = ( seeds_timestamp = (
minibatch.timestamp if hasattr(minibatch, "timestamp") else None minibatch.timestamp if hasattr(minibatch, "timestamp") else None
) )
elif minibatch.seeds is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_seeds,
) = SubgraphSampler._seeds_preprocess(minibatch)
else: else:
raise ValueError( raise ValueError(
f"Invalid minibatch {minibatch}: Either `node_pairs` or " f"Invalid minibatch {minibatch}: One of `node_pairs`, "
"`seed_nodes` should have a value." "`seed_nodes` and `seeds` should have a value."
) )
minibatch._seed_nodes = seeds minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp minibatch._seeds_timestamp = seeds_timestamp
...@@ -226,6 +233,116 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -226,6 +233,116 @@ class SubgraphSampler(MiniBatchTransformer):
""" """
return datapipe.transform(self._sample) return datapipe.transform(self._sample)
@staticmethod
def _seeds_preprocess(minibatch):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
compacting seeds based on their types and timestamps.
Parameters
----------
minibatch: MiniBatch
The minibatch.
Returns
-------
unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
A tensor or a dictionary of tensors representing the unique seeds.
In heterogeneous graphs, seeds are returned for each node type.
nodes_timestamp: None or a torch.Tensor or Dict[str, torch.Tensor]
Containing timestamps for each seed. This is only returned if
`minibatch` includes timestamps and the graph is temporal.
compacted_seeds: torch.tensor or a Dict[str, torch.Tensor]
Representation of compacted seeds corresponding to 'seeds', where
all node ids inside are compacted.
"""
use_timestamp = hasattr(minibatch, "timestamp")
seeds = minibatch.seeds
is_heterogeneous = isinstance(seeds, Dict)
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, pair in seeds.items():
assert pair.ndim == 1 or (
pair.ndim == 2 and pair.shape[1] == 2
), (
"Only tensor with shape 1*N and N*2 is "
+ f"supported now, but got {pair.shape}."
)
ntypes = etype[:].split(":")[::2]
pair = pair.view(pair.shape[0], -1)
if use_timestamp:
negative_ratio = (
pair.shape[0] // minibatch.timestamp[etype].shape[0] - 1
)
neg_timestamp = minibatch.timestamp[
etype
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(pair[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[etype]
)
nodes_timestamp[ntype].append(neg_timestamp)
# Unique and compact the collected nodes.
if use_timestamp:
(
unique_seeds,
nodes_timestamp,
compacted,
) = compact_temporal_nodes(nodes, nodes_timestamp)
else:
unique_seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for etype, pair in seeds.items():
if pair.ndim == 1:
compacted_seeds[etype] = compacted[etype].pop(0)
else:
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
else:
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
nodes_timestamp = None
if use_timestamp:
# Timestamp for source and destination nodes are the same.
negative_ratio = (
seeds.shape[0] // minibatch.timestamp.shape[0] - 1
)
neg_timestamp = minibatch.timestamp.repeat_interleave(
negative_ratio
)
seeds_timestamp = torch.cat(
(minibatch.timestamp, neg_timestamp)
)
nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
# Unique and compact the collected nodes.
if use_timestamp:
(
unique_seeds,
nodes_timestamp,
compacted,
) = compact_temporal_nodes(nodes, nodes_timestamp)
else:
unique_seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
# Map back in same order as collect.
compacted_seeds = compacted[0].view(seeds.shape)
return (
unique_seeds,
nodes_timestamp,
compacted_seeds,
)
def sample_subgraphs(self, seeds, seeds_timestamp): def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints. """Sample subgraphs from the given seeds, possibly with temporal constraints.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment