Unverified Commit f5981789 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Support temporal sampling in SubgraphSampler. (#6846)

parent d32a5980
......@@ -5,6 +5,7 @@ from .gpu_cached_feature import *
from .in_subgraph_sampler import *
from .legacy_dataset import *
from .neighbor_sampler import *
from .temporal_neighbor_sampler import *
from .ondisk_dataset import *
from .ondisk_metadata import *
from .sampled_subgraph_impl import *
......
......@@ -761,8 +761,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
def temporal_sample_neighbors(
self,
nodes: torch.Tensor,
input_nodes_timestamp: torch.Tensor,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
......
......@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self.graph = graph
self.sampler = graph.in_subgraph
def sample_subgraphs(self, seeds, seeds_timestamp=None):
def sample_subgraphs(self, seeds, seeds_timestamp):
subgraph = self.sampler(seeds)
(
original_row_node_ids,
......
......@@ -117,7 +117,7 @@ class NeighborSampler(SubgraphSampler):
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors
def sample_subgraphs(self, seeds, seeds_timestamp=None):
def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
......
......@@ -89,7 +89,10 @@ class TemporalNeighborSampler(SubgraphSampler):
self.edge_timestamp_attr_name = edge_timestamp_attr_name
self.sampler = graph.temporal_sample_neighbors
def sample_subgraphs(self, seeds, seeds_timestamp=None):
def sample_subgraphs(self, seeds, seeds_timestamp):
assert (
seeds_timestamp is not None
), "seeds_timestamp must be provided for temporal neighbor sampling."
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
......@@ -117,10 +120,10 @@ class TemporalNeighborSampler(SubgraphSampler):
original_row_node_ids,
compacted_csc_formats,
row_timestamps,
) = compact_csc_format(subgraph.node_pairs, seeds, seeds_timestamp)
) = compact_csc_format(subgraph.sampled_csc, seeds, seeds_timestamp)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_formats,
sampled_csc=compacted_csc_formats,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......
......@@ -61,6 +61,61 @@ def unique_and_compact(
return unique_and_compact_per_type(nodes)
def compact_temporal_nodes(nodes, nodes_timestamp):
    """Compact a list of temporal nodes without deduplication.

    Because no uniquing is performed, the nodes and their timestamps are
    simply concatenated, and the compacted IDs are the consecutive integers
    starting from 0.

    Parameters
    ----------
    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Nodes to compact. With a dict input the compaction is performed
        independently per node type (heterogeneous graphs); with a list it
        is performed over all tensors together (homogeneous graphs).
    nodes_timestamp : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Timestamps aligned element-wise with ``nodes``.

    Returns
    -------
    Tuple[nodes, nodes_timestamp, compacted_node_list]
        The concatenated nodes, the concatenated timestamps, and the list of
        tensors holding the compacted (renumbered) node IDs, split to mirror
        the input tensors.
    """

    def _compact(type_nodes, type_timestamps):
        # Record per-tensor lengths so the compacted IDs can be split back
        # into the same shapes as the input list.
        lengths = [t.size(0) for t in type_nodes]
        cat_nodes = torch.cat(type_nodes)
        cat_timestamps = torch.cat(type_timestamps)
        # Without uniquing, compacted IDs are just 0..N-1.
        new_ids = torch.arange(
            cat_nodes.numel(),
            dtype=cat_nodes.dtype,
            device=cat_nodes.device,
        )
        return cat_nodes, cat_timestamps, list(new_ids.split(lengths))

    if not isinstance(nodes, dict):
        return _compact(nodes, nodes_timestamp)
    ret_nodes, ret_timestamps, compacted = {}, {}, {}
    for ntype, per_type_nodes in nodes.items():
        (
            ret_nodes[ntype],
            ret_timestamps[ntype],
            compacted[ntype],
        ) = _compact(per_type_nodes, nodes_timestamp[ntype])
    return ret_nodes, ret_timestamps, compacted
def unique_and_compact_csc_formats(
csc_formats: Union[
Tuple[torch.Tensor, torch.Tensor],
......@@ -236,7 +291,8 @@ def compact_csc_format(
A tensor of original row node IDs (per type) of all nodes in the input.
The compacted CSC formats, where node IDs are replaced with mapped node
IDs ranging from 0 to N.
The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
The source timestamps (per type) of all nodes in the input if
`dst_timestamps` is given.
Examples
--------
......@@ -318,8 +374,13 @@ def compact_csc_format(
src_timestamps = None
if has_timestamp:
src_timestamps = _broadcast_timestamps(
compacted_csc_formats, dst_timestamps
src_timestamps = torch.cat(
[
dst_timestamps,
_broadcast_timestamps(
compacted_csc_formats, dst_timestamps
),
]
)
else:
compacted_csc_formats = {}
......
......@@ -6,7 +6,7 @@ from typing import Dict
from torch.utils.data import functional_datapipe
from .base import etype_str_to_tuple
from .internal import unique_and_compact
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer
__all__ = [
......@@ -40,12 +40,16 @@ class SubgraphSampler(MiniBatchTransformer):
if minibatch.node_pairs is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = self._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
seeds_timestamp = (
minibatch.timestamp if hasattr(minibatch, "timestamp") else None
)
else:
raise ValueError(
f"Invalid minibatch {minibatch}: Either `node_pairs` or "
......@@ -54,10 +58,11 @@ class SubgraphSampler(MiniBatchTransformer):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(seeds)
) = self.sample_subgraphs(seeds, seeds_timestamp)
return minibatch
def _node_pairs_preprocess(self, minibatch):
use_timestamp = hasattr(minibatch, "timestamp")
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
has_neg_src = neg_src is not None
......@@ -72,20 +77,44 @@ class SubgraphSampler(MiniBatchTransformer):
)
# Collect nodes from all types of input.
nodes = defaultdict(list)
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, (src, dst) in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src)
nodes[dst_type].append(dst)
if use_timestamp:
nodes_timestamp[src_type].append(minibatch.timestamp[etype])
nodes_timestamp[dst_type].append(minibatch.timestamp[etype])
if has_neg_src:
for etype, src in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
nodes[src_type].append(src.view(-1))
if use_timestamp:
nodes_timestamp[src_type].append(
minibatch.timestamp[etype].repeat_interleave(
src.shape[-1]
)
)
if has_neg_dst:
for etype, dst in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
nodes[dst_type].append(dst.view(-1))
if use_timestamp:
nodes_timestamp[dst_type].append(
minibatch.timestamp[etype].repeat_interleave(
dst.shape[-1]
)
)
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
if use_timestamp:
seeds, nodes_timestamp, compacted = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
(
compacted_node_pairs,
compacted_negative_srcs,
......@@ -108,12 +137,30 @@ class SubgraphSampler(MiniBatchTransformer):
else:
# Collect nodes from all types of input.
nodes = list(node_pairs)
nodes_timestamp = None
if use_timestamp:
# Timestamp for source and destination nodes are the same.
nodes_timestamp = [minibatch.timestamp, minibatch.timestamp]
if has_neg_src:
nodes.append(neg_src.view(-1))
if use_timestamp:
nodes_timestamp.append(
minibatch.timestamp.repeat_interleave(neg_src.shape[-1])
)
if has_neg_dst:
nodes.append(neg_dst.view(-1))
if use_timestamp:
nodes_timestamp.append(
minibatch.timestamp.repeat_interleave(neg_dst.shape[-1])
)
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
if use_timestamp:
seeds, nodes_timestamp, compacted = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
# Map back in same order as collect.
compacted_node_pairs = tuple(compacted[:2])
compacted = compacted[2:]
......@@ -132,13 +179,14 @@ class SubgraphSampler(MiniBatchTransformer):
)
return (
seeds,
nodes_timestamp,
compacted_node_pairs,
compacted_negative_srcs if has_neg_src else None,
compacted_negative_dsts if has_neg_dst else None,
)
def sample_subgraphs(self, seeds, seeds_timestamp=None):
"""Sample subgraphs from the given seeds.
def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
Any subclass of SubgraphSampler should implement this method.
......@@ -147,6 +195,11 @@ class SubgraphSampler(MiniBatchTransformer):
seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
The seed nodes.
seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
The timestamps of the seed nodes. If given, the sampled subgraphs
should not contain any nodes or edges that are newer than the
timestamps of the seed nodes. Default: None.
Returns
-------
Union[torch.Tensor, Dict[str, torch.Tensor]]
......
import unittest
from enum import Enum
from functools import partial
import backend as F
......@@ -12,6 +14,31 @@ from torchdata.datapipes.iter import Mapper
from . import gb_test_utils
# Skip all tests on GPU.
pytestmark = pytest.mark.skipif(
F._default_context_str != "cpu",
reason="GraphBolt sampling tests are only supported on CPU.",
)
class SamplerType(Enum):
    """Kinds of subgraph samplers exercised by the parametrized tests."""

    # Plain uniform neighbor sampling.
    Normal = 0
    # Layer-dependent (labor) neighbor sampling.
    Layer = 1
    # Temporal neighbor sampling constrained by node/edge timestamps.
    Temporal = 2
def _get_sampler(sampler_type):
    """Return the sampler constructor matching ``sampler_type``."""
    non_temporal = {
        SamplerType.Normal: gb.NeighborSampler,
        SamplerType.Layer: gb.LayerNeighborSampler,
    }
    if sampler_type in non_temporal:
        return non_temporal[sampler_type]
    # The temporal sampler additionally needs the names of the timestamp
    # attributes on nodes and edges.
    return partial(
        gb.TemporalNeighborSampler,
        node_timestamp_attr_name="timestamp",
        edge_timestamp_attr_name="timestamp",
    )
def test_SubgraphSampler_invoke():
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......@@ -76,17 +103,29 @@ def test_NeighborSampler_fanouts(labor):
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node(sampler_type):
    """Smoke-test node-seed subgraph sampling for every sampler type.

    The diff scrape interleaved the pre-change lines (the old ``labor``
    parametrization, the old ``itemset``/``Sampler`` assignments) with the
    post-change ones, leaving a duplicate ``def`` and conflicting
    statements; this is the coherent post-commit version.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    items = torch.arange(10)
    names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        # Temporal sampling needs timestamps on the graph as well as on the
        # seed items.
        graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
        graph.edge_attributes = {
            "timestamp": torch.arange(len(graph.indices)).to(F.ctx())
        }
        items = (items, torch.arange(10))
        names = ("seed_nodes", "timestamp")
    itemset = gb.ItemSet(items, names=names)
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    sampler_dp = sampler(item_sampler, graph, fanouts)
    # 10 seeds with batch_size 2 -> 5 minibatches.
    assert len(list(sampler_dp)) == 5
......@@ -95,33 +134,57 @@ def to_link_batch(data):
return block
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link(sampler_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
items = torch.arange(20).reshape(-1, 2)
names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = ("node_pairs", "timestamp")
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
items = torch.arange(20).reshape(-1, 2)
names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = ("node_pairs", "timestamp")
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
......@@ -148,34 +211,64 @@ def get_hetero_graph():
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node_Hetero(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node_Hetero(sampler_type):
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
)
items = torch.arange(3)
names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(item_sampler, graph, fanouts)
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2
for minibatch in sampler_dp:
assert len(minibatch.sampled_subgraphs) == num_layer
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="node_pairs",
first_items,
names=first_names,
),
"n2:e2:n1": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="node_pairs",
second_items,
names=second_names,
),
}
)
......@@ -183,24 +276,42 @@ def test_SubgraphSampler_Link_Hetero(labor):
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="node_pairs",
first_items,
names=first_names,
),
"n2:e2:n1": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="node_pairs",
second_items,
names=second_names,
),
}
)
......@@ -209,8 +320,8 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
......@@ -219,8 +330,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
F._default_context_str != "cpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Random_Hetero_Graph(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type):
num_nodes = 5
num_edges = 9
num_ntypes = 3
......@@ -235,10 +349,14 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes
)
node_attributes = {}
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
if sampler_type == SamplerType.Temporal:
node_attributes["timestamp"] = torch.randint(0, 10, (num_nodes,))
edge_attributes["timestamp"] = torch.randint(0, 10, (num_edges,))
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
......@@ -246,21 +364,31 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
).to(F.ctx())
first_items = torch.tensor([0])
first_names = "seed_nodes"
second_items = torch.tensor([0])
second_names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
first_items = (first_items, torch.randint(0, 10, (1,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (1,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
"n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
"n2": gb.ItemSet(first_items, names=first_names),
"n1": gb.ItemSet(second_items, names=second_names),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler = _get_sampler(sampler_type)
sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)
sampler_dp = sampler(item_sampler, graph, fanouts, replace=True)
for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs:
......@@ -289,23 +417,40 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
F._default_context_str != "cpu",
reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Homo(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
graph = gb.from_dglgraph(graph, True).to(F.ctx())
seed_nodes = torch.LongTensor([0, 3, 4])
items = seed_nodes
names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
itemset = gb.ItemSet(items, names=names)
item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
F.ctx()
)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
length = [17, 7]
compacted_indices = [
......@@ -334,17 +479,32 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Hetero(labor):
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
)
items = torch.arange(2)
names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (2,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment