"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "57d2f31f20f124a5fd93077060d2f189faed0eb8"
Unverified commit c6abbb13, authored by Mingbang Wang and committed by GitHub
Browse files

[GraphBolt] Add a check assertion for data type of nodes (#6767)

parent 6451807b
...@@ -629,6 +629,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -629,6 +629,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _check_sampler_arguments(self, nodes, fanouts, probs_name): def _check_sampler_arguments(self, nodes, fanouts, probs_name):
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert nodes.dtype == self.indices.dtype, (
f"Data type of nodes must be consistent with "
f"indices.dtype({self.indices.dtype}), but got {nodes.dtype}."
)
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor." assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1 expected_fanout_len = 1
if self.edge_type_to_id: if self.edge_type_to_id:
......
import os import os
import pickle import pickle
import re
import tempfile import tempfile
import unittest import unittest
...@@ -972,9 +973,25 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype): ...@@ -972,9 +973,25 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
graph = gb.fused_csc_sampling_graph(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype) fanouts = torch.LongTensor([2])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = torch.tensor(
[1, 3, 4],
dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),
)
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.node_pairs[0].size(0)
...@@ -1023,12 +1040,37 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1023,12 +1040,37 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
) )
# Sample on both node types. # Sample on both node types.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
"n2": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = { nodes = {
"n1": torch.tensor([0], dtype=indices_dtype), "n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype), "n2": torch.tensor([0], dtype=indices_dtype),
} }
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
...@@ -1051,20 +1093,39 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1051,20 +1093,39 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
assert subgraph.original_edge_ids is None assert subgraph.original_edge_ids is None
# Sample on single node type. # Sample on single node type.
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
)
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
"n2:e2:n1": ( "n2:e2:n1": (
torch.LongTensor([0, 2]), torch.tensor([0, 2], dtype=indices_dtype),
torch.LongTensor([0, 0]), torch.tensor([0, 0], dtype=indices_dtype),
), ),
"n1:e1:n2": ( "n1:e1:n2": (
torch.LongTensor([]), torch.tensor([], dtype=indices_dtype),
torch.LongTensor([]), torch.tensor([], dtype=indices_dtype),
), ),
} }
assert len(subgraph.node_pairs) == 2 assert len(subgraph.node_pairs) == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment