Unverified Commit 6309483d authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Remove old version `subgraph sampler`. (#7305)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 65272a53
...@@ -57,20 +57,7 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -57,20 +57,7 @@ class SubgraphSampler(MiniBatchTransformer):
@staticmethod @staticmethod
def _preprocess(minibatch): def _preprocess(minibatch):
if minibatch.node_pairs is not None: if minibatch.seeds is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = SubgraphSampler._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
seeds_timestamp = (
minibatch.timestamp if hasattr(minibatch, "timestamp") else None
)
elif minibatch.seeds is not None:
( (
seeds, seeds,
seeds_timestamp, seeds_timestamp,
...@@ -78,144 +65,12 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -78,144 +65,12 @@ class SubgraphSampler(MiniBatchTransformer):
) = SubgraphSampler._seeds_preprocess(minibatch) ) = SubgraphSampler._seeds_preprocess(minibatch)
else: else:
raise ValueError( raise ValueError(
f"Invalid minibatch {minibatch}: One of `node_pairs`, " f"Invalid minibatch {minibatch}: `seeds` should have a value."
"`seed_nodes` and `seeds` should have a value."
) )
minibatch._seed_nodes = seeds minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp minibatch._seeds_timestamp = seeds_timestamp
return minibatch return minibatch
@staticmethod
def _node_pairs_preprocess(minibatch):
use_timestamp = hasattr(minibatch, "timestamp")
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pairs, Dict)
if is_heterogeneous:
has_neg_src = has_neg_src and all(
item is not None for item in neg_src.values()
)
has_neg_dst = has_neg_dst and all(
item is not None for item in neg_dst.values()
)
# Collect nodes from all types of input.
nodes = defaultdict(list)
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, (src, dst) in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src)
nodes[dst_type].append(dst)
if use_timestamp:
nodes_timestamp[src_type].append(minibatch.timestamp[etype])
nodes_timestamp[dst_type].append(minibatch.timestamp[etype])
if has_neg_src:
for etype, src in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
nodes[src_type].append(src.view(-1))
if use_timestamp:
nodes_timestamp[src_type].append(
minibatch.timestamp[etype].repeat_interleave(
src.shape[-1]
)
)
if has_neg_dst:
for etype, dst in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
nodes[dst_type].append(dst.view(-1))
if use_timestamp:
nodes_timestamp[dst_type].append(
minibatch.timestamp[etype].repeat_interleave(
dst.shape[-1]
)
)
# Unique and compact the collected nodes.
if use_timestamp:
seeds, nodes_timestamp, compacted = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
(
compacted_node_pairs,
compacted_negative_srcs,
compacted_negative_dsts,
) = ({}, {}, {})
# Map back in same order as collect.
for etype, _ in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_node_pairs[etype] = (src, dst)
if has_neg_src:
for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_srcs[etype] = compacted[src_type].pop(0)
compacted_negative_srcs[etype] = compacted_negative_srcs[
etype
].view(neg_src[etype].shape)
if has_neg_dst:
for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
compacted_negative_dsts[etype] = compacted_negative_dsts[
etype
].view(neg_dst[etype].shape)
else:
# Collect nodes from all types of input.
nodes = list(node_pairs)
nodes_timestamp = None
if use_timestamp:
# Timestamp for source and destination nodes are the same.
nodes_timestamp = [minibatch.timestamp, minibatch.timestamp]
if has_neg_src:
nodes.append(neg_src.view(-1))
if use_timestamp:
nodes_timestamp.append(
minibatch.timestamp.repeat_interleave(neg_src.shape[-1])
)
if has_neg_dst:
nodes.append(neg_dst.view(-1))
if use_timestamp:
nodes_timestamp.append(
minibatch.timestamp.repeat_interleave(neg_dst.shape[-1])
)
# Unique and compact the collected nodes.
if use_timestamp:
seeds, nodes_timestamp, compacted = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
# Map back in same order as collect.
compacted_node_pairs = tuple(compacted[:2])
compacted = compacted[2:]
if has_neg_src:
compacted_negative_srcs = compacted.pop(0)
# Since we need to calculate the neg_ratio according to the
# compacted_negatvie_srcs shape, we need to reshape it back.
compacted_negative_srcs = compacted_negative_srcs.view(
neg_src.shape
)
if has_neg_dst:
compacted_negative_dsts = compacted.pop(0)
# Same as above.
compacted_negative_dsts = compacted_negative_dsts.view(
neg_dst.shape
)
return (
seeds,
nodes_timestamp,
compacted_node_pairs,
compacted_negative_srcs if has_neg_src else None,
compacted_negative_dsts if has_neg_dst else None,
)
def _sample(self, minibatch): def _sample(self, minibatch):
( (
minibatch.input_nodes, minibatch.input_nodes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment