Unverified Commit 76d66fd3 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix dtype mismatch in EdgeDataLoader on Windows (#2188)

* fix node and edge dataloader on windows

* fix distributed
parent 782527d4
......@@ -8,6 +8,7 @@ from ..base import NID, EID
from .. import backend as F
from .. import utils
from ..convert import heterograph
from ..distributed.dist_graph import DistGraph
# pylint: disable=unused-argument
def assign_block_eids(block, frontier):
......@@ -244,6 +245,7 @@ class BlockSampler(object):
assign_block_eids(block, frontier)
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
# Pre-generate CSR format so that it can be used in training directly
block.create_formats_()
blocks.insert(0, block)
......@@ -309,6 +311,7 @@ class NodeCollator(Collator):
"""
def __init__(self, g, nids, block_sampler):
self.g = g
self._is_distributed = isinstance(g, DistGraph)
if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \
"nids should be a dict of node type and ids for graph with multiple node types"
......@@ -352,6 +355,15 @@ class NodeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
# TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
# this function does not work. I'm again skipping this step as a workaround.
# We need to fix this.
if not self._is_distributed:
if isinstance(items, dict):
items = utils.prepare_tensor_dict(self.g, items, 'items')
else:
items = utils.prepare_tensor(self.g, items, 'items')
blocks = self.block_sampler.sample_blocks(self.g, items)
output_nodes = blocks[-1].dstdata[NID]
input_nodes = blocks[0].srcdata[NID]
......@@ -559,10 +571,11 @@ class EdgeCollator(Collator):
def _collate(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else:
items = F.zerocopy_from_numpy(np.asarray(items))
items = utils.prepare_tensor(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID]
......@@ -582,10 +595,11 @@ class EdgeCollator(Collator):
def _collate_with_negative_sampling(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else:
items = F.zerocopy_from_numpy(np.asarray(items))
items = utils.prepare_tensor(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
induced_edges = pair_graph.edata[EID]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment