Unverified commit 2f585940, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Modify `_seeds_preprocess` in subgraph_sampler. (#7242)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 7815fe8a
......@@ -584,15 +584,17 @@ class MiniBatch:
"node_features",
"edge_features",
]
elif self.seeds is not None and self.compacted_seeds is not None:
elif self.seeds is not None:
# Node/link/edge related tasks.
transfer_attrs = [
"labels",
"compacted_seeds",
"sampled_subgraphs",
"node_features",
"edge_features",
]
# Link/edge related tasks.
if self.compacted_seeds is not None:
transfer_attrs.append("compacted_seeds")
if self.indexes is not None:
transfer_attrs.append("indexes")
if self.labels is None:
......
......@@ -266,24 +266,32 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, pair in seeds.items():
assert pair.ndim == 1 or (
pair.ndim == 2 and pair.shape[1] == 2
), (
for etype, typed_seeds in seeds.items():
# When typed_seeds is a one-dimensional tensor, it represents
# seed nodes, which does not need to do unique and compact.
if typed_seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
return seeds, nodes_timestamp, None
assert typed_seeds.ndim == 2 and typed_seeds.shape[1] == 2, (
"Only tensor with shape 1*N and N*2 is "
+ f"supported now, but got {pair.shape}."
+ f"supported now, but got {typed_seeds.shape}."
)
ntypes = etype[:].split(":")[::2]
pair = pair.view(pair.shape[0], -1)
if use_timestamp:
negative_ratio = (
pair.shape[0] // minibatch.timestamp[etype].shape[0] - 1
typed_seeds.shape[0]
// minibatch.timestamp[etype].shape[0]
- 1
)
neg_timestamp = minibatch.timestamp[
etype
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(pair[:, i])
nodes[ntype].append(typed_seeds[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[etype]
......@@ -301,15 +309,21 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for etype, pair in seeds.items():
if pair.ndim == 1:
compacted_seeds[etype] = compacted[etype].pop(0)
else:
for etype, typed_seeds in seeds.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
else:
# When seeds is a one-dimensional tensor, it represents seed nodes,
# which does not need to do unique and compact.
if seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
return seeds, nodes_timestamp, None
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
nodes_timestamp = None
......
......@@ -231,7 +231,6 @@ def test_CopyToWithMiniBatches(task):
copied_attrs = [
"node_features",
"edge_features",
"compacted_seeds",
"sampled_subgraphs",
"labels",
"blocks",
......@@ -239,7 +238,6 @@ def test_CopyToWithMiniBatches(task):
elif task == "node_inference":
copied_attrs = [
"seeds",
"compacted_seeds",
"sampled_subgraphs",
"blocks",
"labels",
......@@ -258,7 +256,6 @@ def test_CopyToWithMiniBatches(task):
copied_attrs = [
"node_features",
"edge_features",
"compacted_seeds",
"sampled_subgraphs",
"labels",
"blocks",
......
......@@ -1027,10 +1027,6 @@ def test_SubgraphSampler_Node(sampler_type):
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5
for data in sampler_dp:
assert torch.equal(
data.compacted_seeds, torch.tensor([0, 1]).to(F.ctx())
)
@pytest.mark.parametrize(
......@@ -1119,14 +1115,8 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2
expected_compacted_seeds = {"n2": [torch.tensor([0, 1]), torch.tensor([0])]}
for step, minibatch in enumerate(sampler_dp):
assert len(minibatch.sampled_subgraphs) == num_layer
for etype, compacted_seeds in minibatch.compacted_seeds.items():
assert torch.equal(
compacted_seeds,
expected_compacted_seeds[etype][step].to(F.ctx()),
)
@pytest.mark.parametrize(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment