Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f758db38
Unverified
Commit
f758db38
authored
Mar 26, 2022
by
Quan (Andy) Gan
Committed by
GitHub
Mar 26, 2022
Browse files
[Bug] Fix dtype mismatch in heterogeneous DataLoader (#3878)
* fix * unit test
parent
e9632568
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
7 deletions
+30
-7
python/dgl/dataloading/dataloader.py
python/dgl/dataloading/dataloader.py
+6
-4
python/dgl/utils/internal.py
python/dgl/utils/internal.py
+4
-0
tests/pytorch/test_dataloader.py
tests/pytorch/test_dataloader.py
+20
-3
No files found.
python/dgl/dataloading/dataloader.py
View file @
f758db38
...
...
@@ -21,7 +21,7 @@ from ..heterograph import DGLHeteroGraph
from
..
import
ndarray
as
nd
from
..utils
import
(
recursive_apply
,
ExceptionWrapper
,
recursive_apply_pair
,
set_num_threads
,
create_shared_mem_array
,
get_shared_mem_array
,
context_of
)
create_shared_mem_array
,
get_shared_mem_array
,
context_of
,
dtype_of
)
from
..frame
import
LazyFeature
from
..storages
import
wrap_storage
from
.base
import
BlockSampler
,
as_edge_prediction_sampler
...
...
@@ -86,9 +86,11 @@ class _TensorizedDatasetIter(object):
def _get_id_tensor_from_mapping(indices, device, keys):
    """Flatten a per-key mapping of ID tensors into a single two-column tensor.

    Parameters
    ----------
    indices : Mapping
        Maps a key to a tensor of IDs.  Keys listed in ``keys`` but missing
        from ``indices`` contribute zero rows.
    device
        Device on which the per-key lengths and the type IDs are created.
    keys : sequence
        Ordered keys; the position of a key in this sequence is used as its
        type ID.

    Returns
    -------
    Tensor
        ``torch.stack([type_ids, all_indices], 1)``: column 0 holds the type ID
        of each entry and column 1 the ID itself.
    """
    # Create the helper tensors with the same dtype as the ID tensors instead of
    # the default int64; presumably this is what keeps the final ``torch.stack``
    # from mixing dtypes when the IDs are not int64 -- confirm against callers.
    dtype = dtype_of(indices)
    lengths = torch.tensor(
        [(indices[k].shape[0] if k in indices else 0) for k in keys],
        dtype=dtype,
        device=device,
    )
    # Repeat each key's position once per ID that belongs to that key.
    type_ids = torch.arange(
        len(keys), dtype=dtype, device=device
    ).repeat_interleave(lengths)
    # Concatenate in ``keys`` order so rows line up with ``type_ids``.
    all_indices = torch.cat([indices[k] for k in keys if k in indices])
    return torch.stack([type_ids, all_indices], 1)
...
...
python/dgl/utils/internal.py
View file @
f758db38
...
...
@@ -1019,4 +1019,8 @@ def context_of(data):
else
:
return
F
.
context
(
data
)
def dtype_of(data):
    """Return the dtype of ``data``.

    ``data`` may be a single tensor or a mapping of tensors.  For a mapping,
    the dtype of the first value yielded by ``data.values()`` is returned.
    """
    if isinstance(data, Mapping):
        first_value = next(iter(data.values()))
        return F.dtype(first_value)
    else:
        return F.dtype(data)
_init_api
(
"dgl.utils.internal"
)
tests/pytorch/test_dataloader.py
View file @
f758db38
...
...
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
from
collections
import
defaultdict
from
collections.abc
import
Iterator
,
Mapping
from
itertools
import
product
from
test_utils
import
parametrize_dtype
import
pytest
...
...
@@ -89,6 +90,15 @@ def test_neighbor_nonuniform(num_workers):
elif
seed
==
0
:
assert
neighbors
==
{
3
,
4
}
def
_check_dtype
(
data
,
dtype
,
attr_name
):
if
isinstance
(
data
,
dict
):
for
k
,
v
in
data
.
items
():
assert
getattr
(
v
,
attr_name
)
==
dtype
elif
isinstance
(
data
,
list
):
for
v
in
data
:
assert
getattr
(
v
,
attr_name
)
==
dtype
else
:
assert
getattr
(
data
,
attr_name
)
==
dtype
def
_check_device
(
data
):
if
isinstance
(
data
,
dict
):
...
...
@@ -100,10 +110,11 @@ def _check_device(data):
else
:
assert
data
.
device
==
F
.
ctx
()
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'sampler_name'
,
[
'full'
,
'neighbor'
,
'neighbor2'
])
@
pytest
.
mark
.
parametrize
(
'pin_graph'
,
[
False
,
True
])
def
test_node_dataloader
(
sampler_name
,
pin_graph
):
g1
=
dgl
.
graph
(([
0
,
0
,
0
,
1
,
1
],
[
1
,
2
,
3
,
3
,
4
]))
def
test_node_dataloader
(
idtype
,
sampler_name
,
pin_graph
):
g1
=
dgl
.
graph
(([
0
,
0
,
0
,
1
,
1
],
[
1
,
2
,
3
,
3
,
4
]))
.
astype
(
idtype
)
if
F
.
ctx
()
!=
F
.
cpu
()
and
pin_graph
:
g1
.
create_formats_
()
g1
.
pin_memory_
()
...
...
@@ -123,13 +134,16 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device
(
input_nodes
)
_check_device
(
output_nodes
)
_check_device
(
blocks
)
_check_dtype
(
input_nodes
,
idtype
,
'dtype'
)
_check_dtype
(
output_nodes
,
idtype
,
'dtype'
)
_check_dtype
(
blocks
,
idtype
,
'idtype'
)
g2
=
dgl
.
heterograph
({
(
'user'
,
'follow'
,
'user'
):
([
0
,
0
,
0
,
1
,
1
,
1
,
2
],
[
1
,
2
,
3
,
0
,
2
,
3
,
0
]),
(
'user'
,
'followed-by'
,
'user'
):
([
1
,
2
,
3
,
0
,
2
,
3
,
0
],
[
0
,
0
,
0
,
1
,
1
,
1
,
2
]),
(
'user'
,
'play'
,
'game'
):
([
0
,
1
,
1
,
3
,
5
],
[
0
,
1
,
2
,
0
,
2
]),
(
'game'
,
'played-by'
,
'user'
):
([
0
,
1
,
2
,
0
,
2
],
[
0
,
1
,
1
,
3
,
5
])
})
})
.
astype
(
idtype
)
for
ntype
in
g2
.
ntypes
:
g2
.
nodes
[
ntype
].
data
[
'feat'
]
=
F
.
copy_to
(
F
.
randn
((
g2
.
num_nodes
(
ntype
),
8
)),
F
.
cpu
())
batch_size
=
max
(
g2
.
num_nodes
(
nty
)
for
nty
in
g2
.
ntypes
)
...
...
@@ -146,6 +160,9 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device
(
input_nodes
)
_check_device
(
output_nodes
)
_check_device
(
blocks
)
_check_dtype
(
input_nodes
,
idtype
,
'dtype'
)
_check_dtype
(
output_nodes
,
idtype
,
'dtype'
)
_check_dtype
(
blocks
,
idtype
,
'idtype'
)
if
g1
.
is_pinned
():
g1
.
unpin_memory_
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment