Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
76d66fd3
Unverified
Commit
76d66fd3
authored
Sep 13, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Sep 13, 2020
Browse files
[Bug] Fix dtype mismatch in EdgeDataLoader on Windows (#2188)
* fix node and edge dataloader on windows * fix distributed
parent
782527d4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
4 deletions
+18
-4
python/dgl/dataloading/dataloader.py
python/dgl/dataloading/dataloader.py
+18
-4
No files found.
python/dgl/dataloading/dataloader.py
View file @
76d66fd3
...
@@ -8,6 +8,7 @@ from ..base import NID, EID
...
@@ -8,6 +8,7 @@ from ..base import NID, EID
from
..
import
backend
as
F
from
..
import
backend
as
F
from
..
import
utils
from
..
import
utils
from
..convert
import
heterograph
from
..convert
import
heterograph
from
..distributed.dist_graph
import
DistGraph
# pylint: disable=unused-argument
# pylint: disable=unused-argument
def
assign_block_eids
(
block
,
frontier
):
def
assign_block_eids
(
block
,
frontier
):
...
@@ -244,6 +245,7 @@ class BlockSampler(object):
...
@@ -244,6 +245,7 @@ class BlockSampler(object):
assign_block_eids
(
block
,
frontier
)
assign_block_eids
(
block
,
frontier
)
seed_nodes
=
{
ntype
:
block
.
srcnodes
[
ntype
].
data
[
NID
]
for
ntype
in
block
.
srctypes
}
seed_nodes
=
{
ntype
:
block
.
srcnodes
[
ntype
].
data
[
NID
]
for
ntype
in
block
.
srctypes
}
# Pre-generate CSR format so that it can be used in training directly
# Pre-generate CSR format so that it can be used in training directly
block
.
create_formats_
()
block
.
create_formats_
()
blocks
.
insert
(
0
,
block
)
blocks
.
insert
(
0
,
block
)
...
@@ -309,6 +311,7 @@ class NodeCollator(Collator):
...
@@ -309,6 +311,7 @@ class NodeCollator(Collator):
"""
"""
def
__init__
(
self
,
g
,
nids
,
block_sampler
):
def
__init__
(
self
,
g
,
nids
,
block_sampler
):
self
.
g
=
g
self
.
g
=
g
self
.
_is_distributed
=
isinstance
(
g
,
DistGraph
)
if
not
isinstance
(
nids
,
Mapping
):
if
not
isinstance
(
nids
,
Mapping
):
assert
len
(
g
.
ntypes
)
==
1
,
\
assert
len
(
g
.
ntypes
)
==
1
,
\
"nids should be a dict of node type and ids for graph with multiple node types"
"nids should be a dict of node type and ids for graph with multiple node types"
...
@@ -352,6 +355,15 @@ class NodeCollator(Collator):
...
@@ -352,6 +355,15 @@ class NodeCollator(Collator):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
# TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
# this function does not work. I'm again skipping this step as a workaround.
# We need to fix this.
if
not
self
.
_is_distributed
:
if
isinstance
(
items
,
dict
):
items
=
utils
.
prepare_tensor_dict
(
self
.
g
,
items
,
'items'
)
else
:
items
=
utils
.
prepare_tensor
(
self
.
g
,
items
,
'items'
)
blocks
=
self
.
block_sampler
.
sample_blocks
(
self
.
g
,
items
)
blocks
=
self
.
block_sampler
.
sample_blocks
(
self
.
g
,
items
)
output_nodes
=
blocks
[
-
1
].
dstdata
[
NID
]
output_nodes
=
blocks
[
-
1
].
dstdata
[
NID
]
input_nodes
=
blocks
[
0
].
srcdata
[
NID
]
input_nodes
=
blocks
[
0
].
srcdata
[
NID
]
...
@@ -559,10 +571,11 @@ class EdgeCollator(Collator):
...
@@ -559,10 +571,11 @@ class EdgeCollator(Collator):
def
_collate
(
self
,
items
):
def
_collate
(
self
,
items
):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
items
=
{
k
:
F
.
zerocopy_from_numpy
(
np
.
asarray
(
v
))
for
k
,
v
in
items
.
items
()}
items
=
utils
.
prepare_tensor_dict
(
self
.
g_sampling
,
items
,
'
items
'
)
else
:
else
:
items
=
F
.
zerocopy_from_numpy
(
np
.
asarray
(
items
)
)
items
=
utils
.
prepare_tensor
(
self
.
g_sampling
,
items
,
'
items
'
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
)
seed_nodes
=
pair_graph
.
ndata
[
NID
]
seed_nodes
=
pair_graph
.
ndata
[
NID
]
...
@@ -582,10 +595,11 @@ class EdgeCollator(Collator):
...
@@ -582,10 +595,11 @@ class EdgeCollator(Collator):
def
_collate_with_negative_sampling
(
self
,
items
):
def
_collate_with_negative_sampling
(
self
,
items
):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
items
=
{
k
:
F
.
zerocopy_from_numpy
(
np
.
asarray
(
v
))
for
k
,
v
in
items
.
items
()}
items
=
utils
.
prepare_tensor_dict
(
self
.
g_sampling
,
items
,
'
items
'
)
else
:
else
:
items
=
F
.
zerocopy_from_numpy
(
np
.
asarray
(
items
)
)
items
=
utils
.
prepare_tensor
(
self
.
g_sampling
,
items
,
'
items
'
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
,
preserve_nodes
=
True
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
,
preserve_nodes
=
True
)
induced_edges
=
pair_graph
.
edata
[
EID
]
induced_edges
=
pair_graph
.
edata
[
EID
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment