Unverified Commit 7c51cd16 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Cast sampled data to minimum dtype. (#7131)

parent 8909d1ff
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import dgl import dgl
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from .base import etype_str_to_tuple, expand_indptr from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
from .internal import get_attributes from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
...@@ -231,6 +231,19 @@ class MiniBatch: ...@@ -231,6 +231,19 @@ class MiniBatch:
self.sampled_subgraphs[0].sampled_csc, Dict self.sampled_subgraphs[0].sampled_csc, Dict
) )
# casts to minimum dtype in-place and returns self.
def cast_to_minimum_dtype(v: CSCFormatBase):
# Checks if number of vertices and edges fit into an int32.
dtype = (
torch.int32
if max(v.indptr.size(0) - 2, v.indices.size(0))
<= torch.iinfo(torch.int32).max
else torch.int64
)
v.indptr = v.indptr.to(dtype)
v.indices = v.indices.to(dtype)
return v
blocks = [] blocks = []
for subgraph in self.sampled_subgraphs: for subgraph in self.sampled_subgraphs:
original_row_node_ids = subgraph.original_row_node_ids original_row_node_ids = subgraph.original_row_node_ids
...@@ -242,6 +255,8 @@ class MiniBatch: ...@@ -242,6 +255,8 @@ class MiniBatch:
original_column_node_ids is not None original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph." ), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous: if is_heterogeneous:
for v in subgraph.sampled_csc.values():
cast_to_minimum_dtype(v)
sampled_csc = { sampled_csc = {
etype_str_to_tuple(etype): ( etype_str_to_tuple(etype): (
"csc", "csc",
...@@ -267,7 +282,7 @@ class MiniBatch: ...@@ -267,7 +282,7 @@ class MiniBatch:
for ntype, nodes in original_column_node_ids.items() for ntype, nodes in original_column_node_ids.items()
} }
else: else:
sampled_csc = subgraph.sampled_csc sampled_csc = cast_to_minimum_dtype(subgraph.sampled_csc)
sampled_csc = ( sampled_csc = (
"csc", "csc",
( (
......
...@@ -62,15 +62,15 @@ def test_integration_link_prediction(): ...@@ -62,15 +62,15 @@ def test_integration_link_prediction():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([0, 4]), indices=tensor([0, 4], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([5, 4]), indices=tensor([5, 4], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
...@@ -121,15 +121,15 @@ def test_integration_link_prediction(): ...@@ -121,15 +121,15 @@ def test_integration_link_prediction():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 1, 0]), indices=tensor([4, 1, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 4, 0]), indices=tensor([4, 4, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None, original_edge_ids=None,
...@@ -180,15 +180,15 @@ def test_integration_link_prediction(): ...@@ -180,15 +180,15 @@ def test_integration_link_prediction():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4, 0, 1]), original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]), original_column_node_ids=tensor([5, 4, 0, 1]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4, 0, 1]), original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=None, original_edge_ids=None,
...@@ -287,15 +287,15 @@ def test_integration_node_classification(): ...@@ -287,15 +287,15 @@ def test_integration_node_classification():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 1, 0, 1]), indices=tensor([4, 1, 0, 1], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 4]), original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 3, 1, 2]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 1, 0, 1]), indices=tensor([0, 1, 0, 1], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2]), original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None, original_edge_ids=None,
...@@ -331,15 +331,15 @@ def test_integration_node_classification(): ...@@ -331,15 +331,15 @@ def test_integration_node_classification():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2]), indices=tensor([0, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0]), original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2]), indices=tensor([0, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0]), original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
...@@ -373,15 +373,15 @@ def test_integration_node_classification(): ...@@ -373,15 +373,15 @@ def test_integration_node_classification():
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2]), indices=tensor([0, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4, 0]), original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 1]), indices=tensor([1, 1], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4]), original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
......
...@@ -8,15 +8,17 @@ relation = "A:r:B" ...@@ -8,15 +8,17 @@ relation = "A:r:B"
reverse_relation = "B:rr:A" reverse_relation = "B:rr:A"
def test_minibatch_representation_homo(): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]), indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2, 2, 1, 2]), indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
), ),
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 3]), indptr=torch.tensor([0, 2, 3], dtype=indptr_dtype),
indices=torch.tensor([1, 2, 0]), indices=torch.tensor([1, 2, 0], dtype=indices_dtype),
), ),
] ]
original_column_node_ids = [ original_column_node_ids = [
...@@ -98,15 +100,15 @@ def test_minibatch_representation_homo(): ...@@ -98,15 +100,15 @@ def test_minibatch_representation_homo():
expect_result = str( expect_result = str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2]), indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([10, 11, 12, 13]), original_row_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]), original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_column_node_ids=tensor([10, 11, 12, 13]), original_column_node_ids=tensor([10, 11, 12, 13]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0]), indices=tensor([1, 2, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([10, 11, 12]), original_row_node_ids=tensor([10, 11, 12]),
original_edge_ids=tensor([10, 15, 17]), original_edge_ids=tensor([10, 15, 17]),
...@@ -119,11 +121,11 @@ def test_minibatch_representation_homo(): ...@@ -119,11 +121,11 @@ def test_minibatch_representation_homo():
indices=tensor([3, 4, 5]), indices=tensor([3, 4, 5]),
), ),
tensor([0., 1., 2.])), tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]), node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2]), indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
), ),
CSCFormatBase(indptr=tensor([0, 2, 3]), CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0]), indices=tensor([1, 2, 0], dtype=torch.int32),
)], )],
node_features={'x': tensor([5, 0, 2, 1])}, node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8], negative_srcs=tensor([[8],
...@@ -161,21 +163,24 @@ def test_minibatch_representation_homo(): ...@@ -161,21 +163,24 @@ def test_minibatch_representation_homo():
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(expect_result, result)
def test_minibatch_representation_hetero(): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
csc_formats = [ csc_formats = [
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]), indptr=torch.tensor([0, 1, 2, 3], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 1]), indices=torch.tensor([0, 1, 1], dtype=indices_dtype),
), ),
reverse_relation: gb.CSCFormatBase( reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]), indptr=torch.tensor([0, 0, 0, 1, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0]), indices=torch.tensor([1, 0], dtype=indices_dtype),
), ),
}, },
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0]) indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0], dtype=indices_dtype),
) )
}, },
] ]
...@@ -250,17 +255,17 @@ def test_minibatch_representation_hetero(): ...@@ -250,17 +255,17 @@ def test_minibatch_representation_hetero():
expect_result = str( expect_result = str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes={'B': tensor([10, 15])}, seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1]), indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
)}, )},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])}, original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])}, original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])}, original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
), ),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]), SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
)}, )},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])}, original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])}, original_edge_ids={'A:r:B': tensor([10, 12])},
...@@ -277,13 +282,13 @@ def test_minibatch_representation_hetero(): ...@@ -277,13 +282,13 @@ def test_minibatch_representation_hetero():
indices=tensor([0, 1]), indices=tensor([0, 1]),
)}, )},
{'B': tensor([2, 5])}), {'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1]), indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
)}, )},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]), {'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0]), indices=tensor([1, 0], dtype=torch.int32),
)}], )}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])}, node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8], negative_srcs={'B': tensor([[8],
...@@ -325,10 +330,12 @@ def test_minibatch_representation_hetero(): ...@@ -325,10 +330,12 @@ def test_minibatch_representation_hetero():
)""" )"""
) )
result = str(minibatch) result = str(minibatch)
assert result == expect_result, print(result) assert result == expect_result, print(expect_result, result)
def test_get_dgl_blocks_homo(): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
node_pairs = [ node_pairs = [
( (
torch.tensor([0, 1, 2, 2, 2, 1]), torch.tensor([0, 1, 2, 2, 2, 1]),
...@@ -341,12 +348,12 @@ def test_get_dgl_blocks_homo(): ...@@ -341,12 +348,12 @@ def test_get_dgl_blocks_homo():
] ]
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]), indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2, 2, 1, 2]), indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
), ),
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]), indptr=torch.tensor([0, 1, 3], dtype=indptr_dtype),
indices=torch.tensor([0, 1, 2]), indices=torch.tensor([0, 1, 2], dtype=indices_dtype),
), ),
] ]
original_column_node_ids = [ original_column_node_ids = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment