Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3774945d
Unverified
Commit
3774945d
authored
Oct 18, 2023
by
Rhett Ying
Committed by
GitHub
Oct 18, 2023
Browse files
[GraphBolt] enable slice for integer-init ItemSet (#6457)
parent
b08c446d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
64 additions
and
15 deletions
+64
-15
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+10
-1
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+14
-6
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+28
-0
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+12
-8
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
3774945d
...
@@ -134,12 +134,21 @@ class ItemShufflerAndBatcher:
...
@@ -134,12 +134,21 @@ class ItemShufflerAndBatcher:
if
self
.
_drop_last
and
i
+
self
.
_batch_size
>
len
(
indices
):
if
self
.
_drop_last
and
i
+
self
.
_batch_size
>
len
(
indices
):
break
break
batch_indices
=
indices
[
i
:
i
+
self
.
_batch_size
]
batch_indices
=
indices
[
i
:
i
+
self
.
_batch_size
]
if
len
(
self
.
_item_set
.
_items
)
==
1
:
if
isinstance
(
self
.
_item_set
.
_items
,
int
):
# For integer-initialized item set, `buffer` is a tensor.
yield
buffer
[
batch_indices
]
elif
len
(
self
.
_item_set
.
_items
)
==
1
:
if
isinstance
(
buffer
[
0
],
DGLGraph
):
if
isinstance
(
buffer
[
0
],
DGLGraph
):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
yield
dgl_batch
([
buffer
[
idx
]
for
idx
in
batch_indices
])
yield
dgl_batch
([
buffer
[
idx
]
for
idx
in
batch_indices
])
else
:
else
:
# For item set that's initialized with a single
# tensor, `buffer` is a tensor.
yield
buffer
[
batch_indices
]
yield
buffer
[
batch_indices
]
else
:
else
:
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
yield
tuple
(
item
[
batch_indices
]
for
item
in
buffer
)
yield
tuple
(
item
[
batch_indices
]
for
item
in
buffer
)
buffer
=
None
buffer
=
None
start
=
end
start
=
end
...
...
python/dgl/graphbolt/itemset.py
View file @
3774945d
...
@@ -37,7 +37,7 @@ class ItemSet:
...
@@ -37,7 +37,7 @@ class ItemSet:
>>> list(item_set)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[
torch.arange(0, num)
]
>>> item_set[
:
]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
>>> item_set.names
('seed_nodes',)
('seed_nodes',)
...
@@ -151,12 +151,20 @@ class ItemSet:
...
@@ -151,12 +151,20 @@ class ItemSet:
f
"
{
type
(
self
).
__name__
}
instance doesn't support indexing."
f
"
{
type
(
self
).
__name__
}
instance doesn't support indexing."
)
)
if
isinstance
(
self
.
_items
,
int
):
if
isinstance
(
self
.
_items
,
int
):
assert
isinstance
(
idx
,
(
int
,
torch
.
Tensor
)),
(
if
isinstance
(
idx
,
slice
):
f
"Indexing of integer-initialized
{
type
(
self
).
__name__
}
"
start
,
stop
,
step
=
idx
.
indices
(
self
.
_items
)
f
"instance must be int or torch.Tensor."
return
torch
.
arange
(
start
,
stop
,
step
)
if
isinstance
(
idx
,
int
):
if
idx
<
0
:
idx
+=
self
.
_items
if
idx
<
0
or
idx
>=
self
.
_items
:
raise
IndexError
(
f
"
{
type
(
self
).
__name__
}
index out of range."
)
)
# [Warning] Index range is not checked.
return
idx
return
idx
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
indices must be integer or slice."
)
if
len
(
self
.
_items
)
==
1
:
if
len
(
self
.
_items
)
==
1
:
return
self
.
_items
[
0
][
idx
]
return
self
.
_items
[
0
][
idx
]
return
tuple
(
item
[
idx
]
for
item
in
self
.
_items
)
return
tuple
(
item
[
idx
]
for
item
in
self
.
_items
)
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
3774945d
...
@@ -98,6 +98,34 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
...
@@ -98,6 +98,34 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
assert
torch
.
all
(
minibatch_ids
[:
-
1
]
<=
minibatch_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
minibatch_ids
[:
-
1
]
<=
minibatch_ids
[
1
:])
is
not
shuffle
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
def
test_ItemSet_integer
(
batch_size
,
shuffle
,
drop_last
):
# Node IDs.
num_ids
=
103
item_set
=
gb
.
ItemSet
(
num_ids
,
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
minibatch_ids
=
[]
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
isinstance
(
minibatch
,
gb
.
MiniBatch
)
assert
minibatch
.
seed_nodes
is
not
None
assert
minibatch
.
labels
is
None
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
assert
len
(
minibatch
.
seed_nodes
)
==
batch_size
else
:
if
not
drop_last
:
assert
len
(
minibatch
.
seed_nodes
)
==
num_ids
%
batch_size
else
:
assert
False
minibatch_ids
.
append
(
minibatch
.
seed_nodes
)
minibatch_ids
=
torch
.
cat
(
minibatch_ids
)
assert
torch
.
all
(
minibatch_ids
[:
-
1
]
<=
minibatch_ids
[
1
:])
is
not
shuffle
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
3774945d
...
@@ -114,16 +114,20 @@ def test_ItemSet_seed_nodes():
...
@@ -114,16 +114,20 @@ def test_ItemSet_seed_nodes():
assert
i
==
item
.
item
()
assert
i
==
item
.
item
()
assert
i
==
item_set
[
i
]
assert
i
==
item_set
[
i
]
# Indexing with a slice.
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:],
torch
.
arange
(
0
,
5
))
# Indexing with an integer.
assert
item_set
[
0
]
==
0
assert
item_set
[
-
1
]
==
4
# Indexing that is out of range.
with
pytest
.
raises
(
IndexError
,
match
=
"ItemSet index out of range."
):
_
=
item_set
[
5
]
with
pytest
.
raises
(
IndexError
,
match
=
"ItemSet index out of range."
):
_
=
item_set
[
-
10
]
# Indexing with tensor.
with
pytest
.
raises
(
with
pytest
.
raises
(
AssertionError
,
TypeError
,
match
=
"ItemSet indices must be integer or slice."
match
=
(
"Indexing of integer-initialized ItemSet instance must be int or "
"torch.Tensor."
),
):
):
_
=
item_set
[:]
_
=
item_set
[
torch
.
arange
(
3
)]
# Indexing with an Tensor.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)],
torch
.
arange
(
0
,
5
))
def
test_ItemSet_seed_nodes_labels
():
def
test_ItemSet_seed_nodes_labels
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment