Unverified Commit c6abbb13 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Add a check assertion for data type of nodes (#6767)

parent 6451807b
......@@ -629,6 +629,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert nodes.dtype == self.indices.dtype, (
f"Data type of nodes must be consistent with "
f"indices.dtype({self.indices.dtype}), but got {nodes.dtype}."
)
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.edge_type_to_id:
......
import os
import pickle
import re
import tempfile
import unittest
......@@ -972,9 +973,25 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
fanouts = torch.LongTensor([2])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = torch.tensor(
[1, 3, 4],
dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),
)
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
......@@ -1023,12 +1040,37 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
)
# Sample on both node types.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
"n2": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
......@@ -1051,20 +1093,39 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
assert subgraph.original_edge_ids is None
# Sample on single node type.
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
)
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_node_pairs = {
"n2:e2:n1": (
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
torch.tensor([0, 2], dtype=indices_dtype),
torch.tensor([0, 0], dtype=indices_dtype),
),
"n1:e1:n2": (
torch.LongTensor([]),
torch.LongTensor([]),
torch.tensor([], dtype=indices_dtype),
torch.tensor([], dtype=indices_dtype),
),
}
assert len(subgraph.node_pairs) == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment