Unverified commit 76d66fd3, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Bug] Fix dtype mismatch in EdgeDataLoader on Windows (#2188)

* fix node and edge dataloader on windows

* fix distributed
parent 782527d4
...@@ -8,6 +8,7 @@ from ..base import NID, EID ...@@ -8,6 +8,7 @@ from ..base import NID, EID
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from ..convert import heterograph from ..convert import heterograph
from ..distributed.dist_graph import DistGraph
# pylint: disable=unused-argument # pylint: disable=unused-argument
def assign_block_eids(block, frontier): def assign_block_eids(block, frontier):
...@@ -244,6 +245,7 @@ class BlockSampler(object): ...@@ -244,6 +245,7 @@ class BlockSampler(object):
assign_block_eids(block, frontier) assign_block_eids(block, frontier)
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes} seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
# Pre-generate CSR format so that it can be used in training directly # Pre-generate CSR format so that it can be used in training directly
block.create_formats_() block.create_formats_()
blocks.insert(0, block) blocks.insert(0, block)
...@@ -309,6 +311,7 @@ class NodeCollator(Collator): ...@@ -309,6 +311,7 @@ class NodeCollator(Collator):
""" """
def __init__(self, g, nids, block_sampler): def __init__(self, g, nids, block_sampler):
self.g = g self.g = g
self._is_distributed = isinstance(g, DistGraph)
if not isinstance(nids, Mapping): if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \ assert len(g.ntypes) == 1, \
"nids should be a dict of node type and ids for graph with multiple node types" "nids should be a dict of node type and ids for graph with multiple node types"
...@@ -352,6 +355,15 @@ class NodeCollator(Collator): ...@@ -352,6 +355,15 @@ class NodeCollator(Collator):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
# TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
# this function does not work. I'm again skipping this step as a workaround.
# We need to fix this.
if not self._is_distributed:
if isinstance(items, dict):
items = utils.prepare_tensor_dict(self.g, items, 'items')
else:
items = utils.prepare_tensor(self.g, items, 'items')
blocks = self.block_sampler.sample_blocks(self.g, items) blocks = self.block_sampler.sample_blocks(self.g, items)
output_nodes = blocks[-1].dstdata[NID] output_nodes = blocks[-1].dstdata[NID]
input_nodes = blocks[0].srcdata[NID] input_nodes = blocks[0].srcdata[NID]
...@@ -559,10 +571,11 @@ class EdgeCollator(Collator): ...@@ -559,10 +571,11 @@ class EdgeCollator(Collator):
def _collate(self, items): def _collate(self, items):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()} items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else: else:
items = F.zerocopy_from_numpy(np.asarray(items)) items = utils.prepare_tensor(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items) pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID] seed_nodes = pair_graph.ndata[NID]
...@@ -582,10 +595,11 @@ class EdgeCollator(Collator): ...@@ -582,10 +595,11 @@ class EdgeCollator(Collator):
def _collate_with_negative_sampling(self, items): def _collate_with_negative_sampling(self, items):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()} items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else: else:
items = F.zerocopy_from_numpy(np.asarray(items)) items = utils.prepare_tensor(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True) pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
induced_edges = pair_graph.edata[EID] induced_edges = pair_graph.edata[EID]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment