Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
dbafbe41
Unverified
Commit
dbafbe41
authored
Feb 26, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Feb 26, 2024
Browse files
[GraphBolt] Fix scalar itemset dtype issue (#7147)
parent
045beeba
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
15 deletions
+54
-15
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+41
-13
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+1
-1
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+12
-1
No files found.
python/dgl/graphbolt/itemset.py
View file @
dbafbe41
...
...
@@ -8,6 +8,13 @@ import torch
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
def
is_scalar
(
x
):
"""Checks if the input is a scalar."""
return
(
len
(
x
.
shape
)
==
0
if
isinstance
(
x
,
torch
.
Tensor
)
else
isinstance
(
x
,
int
)
)
class
ItemSet
:
r
"""A wrapper of iterable data or tuple of iterable data.
...
...
@@ -47,7 +54,22 @@ class ItemSet:
>>> item_set.names
('seed_nodes',)
2. Single iterable: seed nodes.
2. Torch scalar: number of nodes. Customizable dtype compared to Integer.
>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seed_nodes',)
3. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
...
...
@@ -58,7 +80,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes',)
3
. Tuple of iterables with same shape: seed nodes and labels.
4
. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
...
...
@@ -72,7 +94,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes', 'labels')
4
. Tuple of iterables with different shape: node pairs and negative dsts.
5
. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
...
...
@@ -94,10 +116,10 @@ class ItemSet:
def
__init__
(
self
,
items
:
Union
[
int
,
Iterable
,
Tuple
[
Iterable
]],
items
:
Union
[
int
,
torch
.
Tensor
,
Iterable
,
Tuple
[
Iterable
]],
names
:
Union
[
str
,
Tuple
[
str
]]
=
None
,
)
->
None
:
if
isinstance
(
items
,
(
int
,
tuple
)):
if
isinstance
(
items
,
tuple
)
or
is_scalar
(
items
):
self
.
_items
=
items
else
:
self
.
_items
=
(
items
,)
...
...
@@ -117,8 +139,9 @@ class ItemSet:
self
.
_names
=
None
def
__iter__
(
self
)
->
Iterator
:
if
isinstance
(
self
.
_items
,
int
):
yield
from
torch
.
arange
(
self
.
_items
)
if
is_scalar
(
self
.
_items
):
dtype
=
getattr
(
self
.
_items
,
"dtype"
,
torch
.
int64
)
yield
from
torch
.
arange
(
self
.
_items
,
dtype
=
dtype
)
return
if
len
(
self
.
_items
)
==
1
:
...
...
@@ -143,8 +166,8 @@ class ItemSet:
yield
tuple
(
item
)
def
__len__
(
self
)
->
int
:
if
is
instance
(
self
.
_items
,
int
):
return
self
.
_items
if
is
_scalar
(
self
.
_items
):
return
int
(
self
.
_items
)
if
isinstance
(
self
.
_items
[
0
],
Sized
):
return
len
(
self
.
_items
[
0
])
raise
TypeError
(
...
...
@@ -158,10 +181,11 @@ class ItemSet:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't support indexing."
)
if
is
instance
(
self
.
_items
,
int
):
if
is
_scalar
(
self
.
_items
):
if
isinstance
(
idx
,
slice
):
start
,
stop
,
step
=
idx
.
indices
(
self
.
_items
)
return
torch
.
arange
(
start
,
stop
,
step
)
start
,
stop
,
step
=
idx
.
indices
(
int
(
self
.
_items
))
dtype
=
getattr
(
self
.
_items
,
"dtype"
,
torch
.
int64
)
return
torch
.
arange
(
start
,
stop
,
step
,
dtype
=
dtype
)
if
isinstance
(
idx
,
int
):
if
idx
<
0
:
idx
+=
self
.
_items
...
...
@@ -169,7 +193,11 @@ class ItemSet:
raise
IndexError
(
f
"
{
type
(
self
).
__name__
}
index out of range."
)
return
idx
return
(
torch
.
tensor
(
idx
,
dtype
=
self
.
_items
.
dtype
)
if
isinstance
(
self
.
_items
,
torch
.
Tensor
)
else
idx
)
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
indices must be integer or slice."
)
...
...
python/dgl/graphbolt/minibatch.py
View file @
dbafbe41
...
...
@@ -231,7 +231,7 @@ class MiniBatch:
self
.
sampled_subgraphs
[
0
].
sampled_csc
,
Dict
)
#
c
asts to minimum dtype in-place and returns self.
#
C
asts to minimum dtype in-place and returns self.
def
cast_to_minimum_dtype
(
v
:
CSCFormatBase
):
# Checks if number of vertices and edges fit into an int32.
dtype
=
(
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
dbafbe41
...
...
@@ -4,7 +4,6 @@ import dgl
import
pytest
import
torch
from
dgl
import
graphbolt
as
gb
from
torch.testing
import
assert_close
def
test_ItemSet_names
():
...
...
@@ -38,6 +37,18 @@ def test_ItemSet_names():
_
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
(
"seed_nodes"
,
"labels"
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_ItemSet_scalar_dtype
(
dtype
):
item_set
=
gb
.
ItemSet
(
torch
.
tensor
(
5
,
dtype
=
dtype
),
names
=
"seed_nodes"
)
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
assert
item
.
dtype
==
dtype
assert
item_set
[
2
]
==
torch
.
tensor
(
2
,
dtype
=
dtype
)
assert
torch
.
equal
(
item_set
[
slice
(
1
,
4
,
2
)],
torch
.
arange
(
1
,
4
,
2
,
dtype
=
dtype
)
)
def
test_ItemSet_length
():
# Integer with valid length
num
=
10
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment