Unverified Commit dbafbe41 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Fix scalar itemset dtype issue (#7147)

parent 045beeba
......@@ -8,6 +8,13 @@ import torch
__all__ = ["ItemSet", "ItemSetDict"]
def is_scalar(x):
"""Checks if the input is a scalar."""
return (
len(x.shape) == 0 if isinstance(x, torch.Tensor) else isinstance(x, int)
)
class ItemSet:
r"""A wrapper of iterable data or tuple of iterable data.
......@@ -47,7 +54,22 @@ class ItemSet:
>>> item_set.names
('seed_nodes',)
2. Single iterable: seed nodes.
2. Torch scalar: number of nodes. Customizable dtype compared to Integer.
>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seed_nodes',)
3. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
......@@ -58,7 +80,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes',)
3. Tuple of iterables with same shape: seed nodes and labels.
4. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
......@@ -72,7 +94,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes', 'labels')
4. Tuple of iterables with different shape: node pairs and negative dsts.
5. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
......@@ -94,10 +116,10 @@ class ItemSet:
def __init__(
self,
items: Union[int, Iterable, Tuple[Iterable]],
items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None,
) -> None:
if isinstance(items, (int, tuple)):
if isinstance(items, tuple) or is_scalar(items):
self._items = items
else:
self._items = (items,)
......@@ -117,8 +139,9 @@ class ItemSet:
self._names = None
def __iter__(self) -> Iterator:
if isinstance(self._items, int):
yield from torch.arange(self._items)
if is_scalar(self._items):
dtype = getattr(self._items, "dtype", torch.int64)
yield from torch.arange(self._items, dtype=dtype)
return
if len(self._items) == 1:
......@@ -143,8 +166,8 @@ class ItemSet:
yield tuple(item)
def __len__(self) -> int:
if isinstance(self._items, int):
return self._items
if is_scalar(self._items):
return int(self._items)
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
......@@ -158,10 +181,11 @@ class ItemSet:
raise TypeError(
f"{type(self).__name__} instance doesn't support indexing."
)
if isinstance(self._items, int):
if is_scalar(self._items):
if isinstance(idx, slice):
start, stop, step = idx.indices(self._items)
return torch.arange(start, stop, step)
start, stop, step = idx.indices(int(self._items))
dtype = getattr(self._items, "dtype", torch.int64)
return torch.arange(start, stop, step, dtype=dtype)
if isinstance(idx, int):
if idx < 0:
idx += self._items
......@@ -169,7 +193,11 @@ class ItemSet:
raise IndexError(
f"{type(self).__name__} index out of range."
)
return idx
return (
torch.tensor(idx, dtype=self._items.dtype)
if isinstance(self._items, torch.Tensor)
else idx
)
raise TypeError(
f"{type(self).__name__} indices must be integer or slice."
)
......
......@@ -231,7 +231,7 @@ class MiniBatch:
self.sampled_subgraphs[0].sampled_csc, Dict
)
# casts to minimum dtype in-place and returns self.
# Casts to minimum dtype in-place and returns self.
def cast_to_minimum_dtype(v: CSCFormatBase):
# Checks if number of vertices and edges fit into an int32.
dtype = (
......
......@@ -4,7 +4,6 @@ import dgl
import pytest
import torch
from dgl import graphbolt as gb
from torch.testing import assert_close
def test_ItemSet_names():
......@@ -38,6 +37,18 @@ def test_ItemSet_names():
_ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype):
item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes")
for i, item in enumerate(item_set):
assert i == item
assert item.dtype == dtype
assert item_set[2] == torch.tensor(2, dtype=dtype)
assert torch.equal(
item_set[slice(1, 4, 2)], torch.arange(1, 4, 2, dtype=dtype)
)
def test_ItemSet_length():
# Integer with valid length
num = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment