Unverified Commit 330571b6 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Fix asseration bug (#5952)

parent 8adb53bb
......@@ -346,6 +346,9 @@ torch::Tensor PickByEtype(
const auto end = offset + num_neighbors;
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < fanouts.size(),
"Etype values exceed the number of fanouts.");
int64_t fanout = fanouts[etype];
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
......
......@@ -270,6 +270,14 @@ class CSCSamplingGraph:
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id)
assert len(fanouts) in [
expected_fanout_len,
1,
], "Fanouts should have the same number of elements as etypes or \
should have a length of 1."
if fanouts.size(0) > 1:
assert (
self.type_per_edge is not None
......@@ -279,10 +287,6 @@ class CSCSamplingGraph:
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
if self.metadata and self.metadata.edge_type_to_id:
assert len(self.metadata.edge_type_to_id) == fanouts.size(
0
), "Fanouts should have the same number of elements as etypes."
if probs_or_mask is not None:
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
......
......@@ -407,9 +407,13 @@ def test_sample_neighbors():
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
ntypes = {"n1": 0, "n2": 1, "n3": 2}
etypes = {("n1", "e1", "n2"): 0, ("n1", "e2", "n3"): 1}
metadata = gb.GraphMetadata(ntypes, etypes)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
graph = gb.from_csc(
indptr, indices, type_per_edge=type_per_edge, metadata=metadata
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -467,9 +471,14 @@ def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
ntypes = {"n1": 0, "n2": 1, "n3": 2}
etypes = {("n1", "e1", "n2"): 0, ("n1", "e2", "n3"): 1}
metadata = gb.GraphMetadata(ntypes, etypes)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
graph = gb.from_csc(
indptr, indices, type_per_edge=type_per_edge, metadata=metadata
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment