Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
557c0a86
"references/vscode:/vscode.git/clone" did not exist on "4d928927d8217b2e25069af93ebc07ceaabcfcb7"
Unverified
Commit
557c0a86
authored
Oct 19, 2023
by
Rhett Ying
Committed by
GitHub
Oct 19, 2023
Browse files
[GraphBolt] enable indexing ItemSetDict in ItemSampler (#6468)
parent
09c33b9f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
94 additions
and
18 deletions
+94
-18
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+48
-18
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+46
-0
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
557c0a86
...
...
@@ -120,6 +120,52 @@ class ItemShufflerAndBatcher:
(
self
.
_buffer_size
+
batch_size
-
1
)
//
batch_size
*
batch_size
)
def
_collate_batch
(
self
,
buffer
,
indices
,
offsets
=
None
):
"""Collate a batch from the buffer. For internal use only."""
if
isinstance
(
buffer
,
torch
.
Tensor
):
# For item set that's initialized with integer or single tensor,
# `buffer` is a tensor.
return
buffer
[
indices
]
elif
isinstance
(
buffer
,
list
)
and
isinstance
(
buffer
[
0
],
DGLGraph
):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
return
dgl_batch
([
buffer
[
idx
]
for
idx
in
indices
])
elif
isinstance
(
buffer
,
tuple
):
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
return
tuple
(
item
[
indices
]
for
item
in
buffer
)
elif
isinstance
(
buffer
,
Mapping
):
# For item set that's initialized with a dict of items,
# `buffer` is a dict of tensors/lists/tuples.
keys
=
list
(
buffer
.
keys
())
key_indices
=
torch
.
searchsorted
(
offsets
,
indices
,
right
=
True
)
-
1
batch
=
{}
for
j
,
key
in
enumerate
(
keys
):
mask
=
(
key_indices
==
j
).
nonzero
().
squeeze
(
1
)
if
len
(
mask
)
==
0
:
continue
batch
[
key
]
=
self
.
_collate_batch
(
buffer
[
key
],
indices
[
mask
]
-
offsets
[
j
]
)
return
batch
raise
TypeError
(
f
"Unsupported buffer type
{
type
(
buffer
).
__name__
}
."
)
def
_calculate_offsets
(
self
,
buffer
):
"""Calculate offsets for each item in buffer. For internal use only."""
if
not
isinstance
(
buffer
,
Mapping
):
return
None
offsets
=
[
0
]
for
value
in
buffer
.
values
():
if
isinstance
(
value
,
torch
.
Tensor
):
offsets
.
append
(
offsets
[
-
1
]
+
len
(
value
))
elif
isinstance
(
value
,
tuple
):
offsets
.
append
(
offsets
[
-
1
]
+
len
(
value
[
0
]))
else
:
raise
TypeError
(
f
"Unsupported buffer type
{
type
(
value
).
__name__
}
."
)
return
torch
.
tensor
(
offsets
)
def
__iter__
(
self
):
buffer
=
None
num_items
=
len
(
self
.
_item_set
)
...
...
@@ -130,26 +176,12 @@ class ItemShufflerAndBatcher:
indices
=
torch
.
arange
(
end
-
start
)
if
self
.
_shuffle
:
np
.
random
.
shuffle
(
indices
.
numpy
())
offsets
=
self
.
_calculate_offsets
(
buffer
)
for
i
in
range
(
0
,
len
(
indices
),
self
.
_batch_size
):
if
self
.
_drop_last
and
i
+
self
.
_batch_size
>
len
(
indices
):
break
batch_indices
=
indices
[
i
:
i
+
self
.
_batch_size
]
if
isinstance
(
self
.
_item_set
.
_items
,
int
):
# For integer-initialized item set, `buffer` is a tensor.
yield
buffer
[
batch_indices
]
elif
len
(
self
.
_item_set
.
_items
)
==
1
:
if
isinstance
(
buffer
[
0
],
DGLGraph
):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
yield
dgl_batch
([
buffer
[
idx
]
for
idx
in
batch_indices
])
else
:
# For item set that's initialized with a single
# tensor, `buffer` is a tensor.
yield
buffer
[
batch_indices
]
else
:
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
yield
tuple
(
item
[
batch_indices
]
for
item
in
buffer
)
yield
self
.
_collate_batch
(
buffer
,
batch_indices
,
offsets
)
buffer
=
None
start
=
end
...
...
@@ -375,8 +407,6 @@ class ItemSampler(IterDataPipe):
item_set
[
0
]
except
TypeError
:
use_indexing
=
False
# [TODO][Rui] For now, we disable indexing for ItemSetDict.
use_indexing
=
(
not
isinstance
(
item_set
,
ItemSetDict
))
and
use_indexing
self
.
_use_indexing
=
use_indexing
self
.
_item_set
=
(
item_set
if
self
.
_use_indexing
else
IterableWrapper
(
item_set
)
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
557c0a86
...
...
@@ -388,6 +388,52 @@ def test_append_with_other_datapipes():
assert
len
(
data
)
==
batch_size
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict whose ItemSets are iterable-only.

    Checks that every minibatch has the expected size (including the
    short final batch vs. ``drop_last`` handling) and that the
    concatenated sampled ids are sorted iff ``shuffle`` is False.
    """

    class IterableOnly:
        # Minimal iterable without __len__/__getitem__, forcing the
        # sampler down its iterable-only code path.
        def __init__(self, start, stop):
            self._start = start
            self._stop = stop

        def __iter__(self):
            return iter(torch.arange(self._start, self._stop))

    num_ids = 205
    # NOTE: a previously computed `chained_ids` list was never used and
    # has been removed; `id_sets` also no longer shadows the per-batch
    # id variable below.
    id_sets = {
        "user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"),
        "item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"),
    }
    item_set = gb.ItemSetDict(id_sets)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            # A short final batch must only appear when drop_last=False.
            assert not drop_last
            expected_batch_size = num_ids % batch_size
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        batch_ids = torch.cat(list(minibatch.seed_nodes.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: the original `torch.all(...) is not shuffle` compared a
    # Tensor's identity to a Python bool, which is always True, so the
    # sortedness check never actually ran. Compare the boolean value.
    is_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    assert is_sorted == (not shuffle)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment