Unverified Commit dbafbe41 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Fix scalar itemset dtype issue (#7147)

parent 045beeba
...@@ -8,6 +8,13 @@ import torch ...@@ -8,6 +8,13 @@ import torch
__all__ = ["ItemSet", "ItemSetDict"] __all__ = ["ItemSet", "ItemSetDict"]
def is_scalar(x):
"""Checks if the input is a scalar."""
return (
len(x.shape) == 0 if isinstance(x, torch.Tensor) else isinstance(x, int)
)
class ItemSet: class ItemSet:
r"""A wrapper of iterable data or tuple of iterable data. r"""A wrapper of iterable data or tuple of iterable data.
...@@ -47,7 +54,22 @@ class ItemSet: ...@@ -47,7 +54,22 @@ class ItemSet:
>>> item_set.names >>> item_set.names
('seed_nodes',) ('seed_nodes',)
2. Single iterable: seed nodes. 2. Torch scalar: number of nodes. Customizable dtype compared to Integer.
>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seed_nodes',)
3. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes") >>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
...@@ -58,7 +80,7 @@ class ItemSet: ...@@ -58,7 +80,7 @@ class ItemSet:
>>> item_set.names >>> item_set.names
('seed_nodes',) ('seed_nodes',)
3. Tuple of iterables with same shape: seed nodes and labels. 4. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10) >>> labels = torch.arange(5, 10)
...@@ -72,7 +94,7 @@ class ItemSet: ...@@ -72,7 +94,7 @@ class ItemSet:
>>> item_set.names >>> item_set.names
('seed_nodes', 'labels') ('seed_nodes', 'labels')
4. Tuple of iterables with different shape: node pairs and negative dsts. 5. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2) >>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3) >>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
...@@ -94,10 +116,10 @@ class ItemSet: ...@@ -94,10 +116,10 @@ class ItemSet:
def __init__( def __init__(
self, self,
items: Union[int, Iterable, Tuple[Iterable]], items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None, names: Union[str, Tuple[str]] = None,
) -> None: ) -> None:
if isinstance(items, (int, tuple)): if isinstance(items, tuple) or is_scalar(items):
self._items = items self._items = items
else: else:
self._items = (items,) self._items = (items,)
...@@ -117,8 +139,9 @@ class ItemSet: ...@@ -117,8 +139,9 @@ class ItemSet:
self._names = None self._names = None
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
if isinstance(self._items, int): if is_scalar(self._items):
yield from torch.arange(self._items) dtype = getattr(self._items, "dtype", torch.int64)
yield from torch.arange(self._items, dtype=dtype)
return return
if len(self._items) == 1: if len(self._items) == 1:
...@@ -143,8 +166,8 @@ class ItemSet: ...@@ -143,8 +166,8 @@ class ItemSet:
yield tuple(item) yield tuple(item)
def __len__(self) -> int: def __len__(self) -> int:
if isinstance(self._items, int): if is_scalar(self._items):
return self._items return int(self._items)
if isinstance(self._items[0], Sized): if isinstance(self._items[0], Sized):
return len(self._items[0]) return len(self._items[0])
raise TypeError( raise TypeError(
...@@ -158,10 +181,11 @@ class ItemSet: ...@@ -158,10 +181,11 @@ class ItemSet:
raise TypeError( raise TypeError(
f"{type(self).__name__} instance doesn't support indexing." f"{type(self).__name__} instance doesn't support indexing."
) )
if isinstance(self._items, int): if is_scalar(self._items):
if isinstance(idx, slice): if isinstance(idx, slice):
start, stop, step = idx.indices(self._items) start, stop, step = idx.indices(int(self._items))
return torch.arange(start, stop, step) dtype = getattr(self._items, "dtype", torch.int64)
return torch.arange(start, stop, step, dtype=dtype)
if isinstance(idx, int): if isinstance(idx, int):
if idx < 0: if idx < 0:
idx += self._items idx += self._items
...@@ -169,7 +193,11 @@ class ItemSet: ...@@ -169,7 +193,11 @@ class ItemSet:
raise IndexError( raise IndexError(
f"{type(self).__name__} index out of range." f"{type(self).__name__} index out of range."
) )
return idx return (
torch.tensor(idx, dtype=self._items.dtype)
if isinstance(self._items, torch.Tensor)
else idx
)
raise TypeError( raise TypeError(
f"{type(self).__name__} indices must be integer or slice." f"{type(self).__name__} indices must be integer or slice."
) )
......
...@@ -231,7 +231,7 @@ class MiniBatch: ...@@ -231,7 +231,7 @@ class MiniBatch:
self.sampled_subgraphs[0].sampled_csc, Dict self.sampled_subgraphs[0].sampled_csc, Dict
) )
# casts to minimum dtype in-place and returns self. # Casts to minimum dtype in-place and returns self.
def cast_to_minimum_dtype(v: CSCFormatBase): def cast_to_minimum_dtype(v: CSCFormatBase):
# Checks if number of vertices and edges fit into an int32. # Checks if number of vertices and edges fit into an int32.
dtype = ( dtype = (
......
...@@ -4,7 +4,6 @@ import dgl ...@@ -4,7 +4,6 @@ import dgl
import pytest import pytest
import torch import torch
from dgl import graphbolt as gb from dgl import graphbolt as gb
from torch.testing import assert_close
def test_ItemSet_names(): def test_ItemSet_names():
...@@ -38,6 +37,18 @@ def test_ItemSet_names(): ...@@ -38,6 +37,18 @@ def test_ItemSet_names():
_ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels")) _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype):
item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes")
for i, item in enumerate(item_set):
assert i == item
assert item.dtype == dtype
assert item_set[2] == torch.tensor(2, dtype=dtype)
assert torch.equal(
item_set[slice(1, 4, 2)], torch.arange(1, 4, 2, dtype=dtype)
)
def test_ItemSet_length(): def test_ItemSet_length():
# Integer with valid length # Integer with valid length
num = 10 num = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment