Unverified commit 2f585940, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Modify `_seeds_preprocess` in subgraph_sampler. (#7242)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 7815fe8a
...@@ -584,15 +584,17 @@ class MiniBatch: ...@@ -584,15 +584,17 @@ class MiniBatch:
"node_features", "node_features",
"edge_features", "edge_features",
] ]
elif self.seeds is not None and self.compacted_seeds is not None: elif self.seeds is not None:
# Node/link/edge related tasks. # Node/link/edge related tasks.
transfer_attrs = [ transfer_attrs = [
"labels", "labels",
"compacted_seeds",
"sampled_subgraphs", "sampled_subgraphs",
"node_features", "node_features",
"edge_features", "edge_features",
] ]
# Link/edge related tasks.
if self.compacted_seeds is not None:
transfer_attrs.append("compacted_seeds")
if self.indexes is not None: if self.indexes is not None:
transfer_attrs.append("indexes") transfer_attrs.append("indexes")
if self.labels is None: if self.labels is None:
......
...@@ -266,24 +266,32 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -266,24 +266,32 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None nodes_timestamp = None
if use_timestamp: if use_timestamp:
nodes_timestamp = defaultdict(list) nodes_timestamp = defaultdict(list)
for etype, pair in seeds.items(): for etype, typed_seeds in seeds.items():
assert pair.ndim == 1 or ( # When typed_seeds is a one-dimensional tensor, it represents
pair.ndim == 2 and pair.shape[1] == 2 # seed nodes, which does not need to do unique and compact.
), ( if typed_seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
return seeds, nodes_timestamp, None
assert typed_seeds.ndim == 2 and typed_seeds.shape[1] == 2, (
"Only tensor with shape 1*N and N*2 is " "Only tensor with shape 1*N and N*2 is "
+ f"supported now, but got {pair.shape}." + f"supported now, but got {typed_seeds.shape}."
) )
ntypes = etype[:].split(":")[::2] ntypes = etype[:].split(":")[::2]
pair = pair.view(pair.shape[0], -1)
if use_timestamp: if use_timestamp:
negative_ratio = ( negative_ratio = (
pair.shape[0] // minibatch.timestamp[etype].shape[0] - 1 typed_seeds.shape[0]
// minibatch.timestamp[etype].shape[0]
- 1
) )
neg_timestamp = minibatch.timestamp[ neg_timestamp = minibatch.timestamp[
etype etype
].repeat_interleave(negative_ratio) ].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes): for i, ntype in enumerate(ntypes):
nodes[ntype].append(pair[:, i]) nodes[ntype].append(typed_seeds[:, i])
if use_timestamp: if use_timestamp:
nodes_timestamp[ntype].append( nodes_timestamp[ntype].append(
minibatch.timestamp[etype] minibatch.timestamp[etype]
...@@ -301,15 +309,21 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -301,15 +309,21 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None nodes_timestamp = None
compacted_seeds = {} compacted_seeds = {}
# Map back in same order as collect. # Map back in same order as collect.
for etype, pair in seeds.items(): for etype, typed_seeds in seeds.items():
if pair.ndim == 1:
compacted_seeds[etype] = compacted[etype].pop(0)
else:
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0) src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0) dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
else: else:
# When seeds is a one-dimensional tensor, it represents seed nodes,
# which does not need to do unique and compact.
if seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
return seeds, nodes_timestamp, None
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = [seeds.view(-1)] nodes = [seeds.view(-1)]
nodes_timestamp = None nodes_timestamp = None
......
...@@ -231,7 +231,6 @@ def test_CopyToWithMiniBatches(task): ...@@ -231,7 +231,6 @@ def test_CopyToWithMiniBatches(task):
copied_attrs = [ copied_attrs = [
"node_features", "node_features",
"edge_features", "edge_features",
"compacted_seeds",
"sampled_subgraphs", "sampled_subgraphs",
"labels", "labels",
"blocks", "blocks",
...@@ -239,7 +238,6 @@ def test_CopyToWithMiniBatches(task): ...@@ -239,7 +238,6 @@ def test_CopyToWithMiniBatches(task):
elif task == "node_inference": elif task == "node_inference":
copied_attrs = [ copied_attrs = [
"seeds", "seeds",
"compacted_seeds",
"sampled_subgraphs", "sampled_subgraphs",
"blocks", "blocks",
"labels", "labels",
...@@ -258,7 +256,6 @@ def test_CopyToWithMiniBatches(task): ...@@ -258,7 +256,6 @@ def test_CopyToWithMiniBatches(task):
copied_attrs = [ copied_attrs = [
"node_features", "node_features",
"edge_features", "edge_features",
"compacted_seeds",
"sampled_subgraphs", "sampled_subgraphs",
"labels", "labels",
"blocks", "blocks",
......
...@@ -1027,10 +1027,6 @@ def test_SubgraphSampler_Node(sampler_type): ...@@ -1027,10 +1027,6 @@ def test_SubgraphSampler_Node(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5 assert len(list(sampler_dp)) == 5
for data in sampler_dp:
assert torch.equal(
data.compacted_seeds, torch.tensor([0, 1]).to(F.ctx())
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1119,14 +1115,8 @@ def test_SubgraphSampler_Node_Hetero(sampler_type): ...@@ -1119,14 +1115,8 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2 assert len(list(sampler_dp)) == 2
expected_compacted_seeds = {"n2": [torch.tensor([0, 1]), torch.tensor([0])]}
for step, minibatch in enumerate(sampler_dp): for step, minibatch in enumerate(sampler_dp):
assert len(minibatch.sampled_subgraphs) == num_layer assert len(minibatch.sampled_subgraphs) == num_layer
for etype, compacted_seeds in minibatch.compacted_seeds.items():
assert torch.equal(
compacted_seeds,
expected_compacted_seeds[etype][step].to(F.ctx()),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment