Unverified commit 7c51cd16, authored by Muhammed Fatih BALIN, committed by GitHub
Browse files

[GraphBolt] Cast sampled data to minimum dtype. (#7131)

parent 8909d1ff
......@@ -8,7 +8,7 @@ import torch
import dgl
from dgl.utils import recursive_apply
from .base import etype_str_to_tuple, expand_indptr
from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph
......@@ -231,6 +231,19 @@ class MiniBatch:
self.sampled_subgraphs[0].sampled_csc, Dict
)
# casts to minimum dtype in-place and returns self.
def cast_to_minimum_dtype(v: CSCFormatBase):
    """Convert the index tensors of ``v`` to the narrowest sufficient dtype.

    ``torch.int32`` is chosen when both ``v.indptr.size(0) - 2`` and
    ``v.indices.size(0)`` are no larger than the int32 maximum; otherwise
    ``torch.int64`` is used. ``v.indptr`` and ``v.indices`` are reassigned
    on ``v`` itself, and the same object is returned.
    """
    # The larger of the two size-derived bounds decides whether the
    # vertex and edge counts fit into an int32.
    largest_bound = max(v.indptr.size(0) - 2, v.indices.size(0))
    int32_limit = torch.iinfo(torch.int32).max
    if largest_bound <= int32_limit:
        target_dtype = torch.int32
    else:
        target_dtype = torch.int64
    # Assign one attribute at a time, in the same order as before.
    v.indptr = v.indptr.to(target_dtype)
    v.indices = v.indices.to(target_dtype)
    return v
blocks = []
for subgraph in self.sampled_subgraphs:
original_row_node_ids = subgraph.original_row_node_ids
......@@ -242,6 +255,8 @@ class MiniBatch:
original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
for v in subgraph.sampled_csc.values():
cast_to_minimum_dtype(v)
sampled_csc = {
etype_str_to_tuple(etype): (
"csc",
......@@ -267,7 +282,7 @@ class MiniBatch:
for ntype, nodes in original_column_node_ids.items()
}
else:
sampled_csc = subgraph.sampled_csc
sampled_csc = cast_to_minimum_dtype(subgraph.sampled_csc)
sampled_csc = (
"csc",
(
......
......@@ -62,15 +62,15 @@ def test_integration_link_prediction():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([0, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([0, 4], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([5, 4], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
......@@ -121,15 +121,15 @@ def test_integration_link_prediction():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
indices=tensor([4, 1, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 1, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
indices=tensor([4, 4, 0]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 4, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None,
......@@ -180,15 +180,15 @@ def test_integration_link_prediction():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
indices=tensor([1, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
indices=tensor([1, 0]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None,
......@@ -287,15 +287,15 @@ def test_integration_node_classification():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 1, 0, 1], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 1, 0, 1], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None,
......@@ -331,15 +331,15 @@ def test_integration_node_classification():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
......@@ -373,15 +373,15 @@ def test_integration_node_classification():
str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 1], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
......
......@@ -8,15 +8,17 @@ relation = "A:r:B"
reverse_relation = "B:rr:A"
def test_minibatch_representation_homo():
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
),
gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 3]),
indices=torch.tensor([1, 2, 0]),
indptr=torch.tensor([0, 2, 3], dtype=indptr_dtype),
indices=torch.tensor([1, 2, 0], dtype=indices_dtype),
),
]
original_column_node_ids = [
......@@ -98,15 +100,15 @@ def test_minibatch_representation_homo():
expect_result = str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_column_node_ids=tensor([10, 11, 12, 13]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([10, 11, 12]),
original_edge_ids=tensor([10, 15, 17]),
......@@ -119,11 +121,11 @@ def test_minibatch_representation_homo():
indices=tensor([3, 4, 5]),
),
tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
),
CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0], dtype=torch.int32),
)],
node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8],
......@@ -161,21 +163,24 @@ def test_minibatch_representation_homo():
assert result == expect_result, print(expect_result, result)
def test_minibatch_representation_hetero():
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
csc_formats = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 1]),
indptr=torch.tensor([0, 1, 2, 3], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 1], dtype=indices_dtype),
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]),
indices=torch.tensor([1, 0]),
indptr=torch.tensor([0, 0, 0, 1, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0], dtype=indices_dtype),
),
},
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0], dtype=indices_dtype),
)
},
]
......@@ -250,17 +255,17 @@ def test_minibatch_representation_hetero():
expect_result = str(
"""MiniBatch(seeds=None,
seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])},
......@@ -277,13 +282,13 @@ def test_minibatch_representation_hetero():
indices=tensor([0, 1]),
)},
{'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8],
......@@ -325,10 +330,12 @@ def test_minibatch_representation_hetero():
)"""
)
result = str(minibatch)
assert result == expect_result, print(result)
assert result == expect_result, print(expect_result, result)
def test_get_dgl_blocks_homo():
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
......@@ -341,12 +348,12 @@ def test_get_dgl_blocks_homo():
]
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
),
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]),
indices=torch.tensor([0, 1, 2]),
indptr=torch.tensor([0, 1, 3], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2], dtype=indices_dtype),
),
]
original_column_node_ids = [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment