"tests/python/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b08c446d66d6323efb42f75aa3dad104e6abbfc9"
Unverified commit f758db38, authored by Quan (Andy) Gan; committed via GitHub.
Browse files

[Bug] Fix dtype mismatch in heterogeneous DataLoader (#3878)

* fix

* unit test
parent e9632568
...@@ -21,7 +21,7 @@ from ..heterograph import DGLHeteroGraph ...@@ -21,7 +21,7 @@ from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd from .. import ndarray as nd
from ..utils import ( from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads, recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
create_shared_mem_array, get_shared_mem_array, context_of) create_shared_mem_array, get_shared_mem_array, context_of, dtype_of)
from ..frame import LazyFeature from ..frame import LazyFeature
from ..storages import wrap_storage from ..storages import wrap_storage
from .base import BlockSampler, as_edge_prediction_sampler from .base import BlockSampler, as_edge_prediction_sampler
...@@ -86,9 +86,11 @@ class _TensorizedDatasetIter(object): ...@@ -86,9 +86,11 @@ class _TensorizedDatasetIter(object):
def _get_id_tensor_from_mapping(indices, device, keys):
    """Flatten a per-type ID mapping into one ``(N, 2)`` tensor of
    ``[type_id, node_id]`` rows.

    Fix for the dtype mismatch in the heterogeneous DataLoader (#3878):
    the auxiliary ``lengths`` and ``type_ids`` tensors are created with the
    same dtype as the incoming ID tensors instead of a hard-coded int64
    (``torch.LongTensor`` / default ``torch.arange``), so int32 graphs no
    longer fail when these are combined with the concatenated IDs.

    Parameters
    ----------
    indices : Mapping[str, Tensor]
        ID tensor per type; types missing from the mapping contribute 0 rows.
    device : torch.device
        Device on which the auxiliary tensors are created.
    keys : sequence of str
        Canonical ordering of type names; a type's position defines its
        ``type_id``.

    Returns
    -------
    Tensor
        Shape ``(N, 2)``; column 0 is the type ID, column 1 the node ID.
    """
    # Derive the ID dtype from the input mapping (assumes all values share it).
    dtype = dtype_of(indices)
    lengths = torch.tensor(
        [(indices[k].shape[0] if k in indices else 0) for k in keys],
        dtype=dtype, device=device)
    # One type_id per key, repeated once per ID of that type.
    type_ids = torch.arange(len(keys), dtype=dtype, device=device).repeat_interleave(lengths)
    all_indices = torch.cat([indices[k] for k in keys if k in indices])
    return torch.stack([type_ids, all_indices], 1)
......
...@@ -1019,4 +1019,8 @@ def context_of(data): ...@@ -1019,4 +1019,8 @@ def context_of(data):
else: else:
return F.context(data) return F.context(data)
def dtype_of(data):
    """Return the dtype of the data, which can be either a tensor or a dict
    of tensors.

    For a mapping, the dtype of the first-iterated value is reported; the
    values are assumed to share a common dtype.
    """
    if isinstance(data, Mapping):
        sample = next(iter(data.values()))
    else:
        sample = data
    return F.dtype(sample)
_init_api("dgl.utils.internal") _init_api("dgl.utils.internal")
...@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader ...@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator, Mapping from collections.abc import Iterator, Mapping
from itertools import product from itertools import product
from test_utils import parametrize_dtype
import pytest import pytest
...@@ -89,6 +90,15 @@ def test_neighbor_nonuniform(num_workers): ...@@ -89,6 +90,15 @@ def test_neighbor_nonuniform(num_workers):
elif seed == 0: elif seed == 0:
assert neighbors == {3, 4} assert neighbors == {3, 4}
def _check_dtype(data, dtype, attr_name):
if isinstance(data, dict):
for k, v in data.items():
assert getattr(v, attr_name) == dtype
elif isinstance(data, list):
for v in data:
assert getattr(v, attr_name) == dtype
else:
assert getattr(data, attr_name) == dtype
def _check_device(data): def _check_device(data):
if isinstance(data, dict): if isinstance(data, dict):
...@@ -100,10 +110,11 @@ def _check_device(data): ...@@ -100,10 +110,11 @@ def _check_device(data):
else: else:
assert data.device == F.ctx() assert data.device == F.ctx()
@parametrize_dtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2']) @pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('pin_graph', [False, True]) @pytest.mark.parametrize('pin_graph', [False, True])
def test_node_dataloader(sampler_name, pin_graph): def test_node_dataloader(idtype, sampler_name, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])) g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
if F.ctx() != F.cpu() and pin_graph: if F.ctx() != F.cpu() and pin_graph:
g1.create_formats_() g1.create_formats_()
g1.pin_memory_() g1.pin_memory_()
...@@ -123,13 +134,16 @@ def test_node_dataloader(sampler_name, pin_graph): ...@@ -123,13 +134,16 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device(input_nodes) _check_device(input_nodes)
_check_device(output_nodes) _check_device(output_nodes)
_check_device(blocks) _check_device(blocks)
_check_dtype(input_nodes, idtype, 'dtype')
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
g2 = dgl.heterograph({ g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]), ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]), ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]), ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]) ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
}) }).astype(idtype)
for ntype in g2.ntypes: for ntype in g2.ntypes:
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu()) g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes) batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
...@@ -146,6 +160,9 @@ def test_node_dataloader(sampler_name, pin_graph): ...@@ -146,6 +160,9 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device(input_nodes) _check_device(input_nodes)
_check_device(output_nodes) _check_device(output_nodes)
_check_device(blocks) _check_device(blocks)
_check_dtype(input_nodes, idtype, 'dtype')
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
if g1.is_pinned(): if g1.is_pinned():
g1.unpin_memory_() g1.unpin_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.