Unverified commit f758db38, authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Bug] Fix dtype mismatch in heterogeneous DataLoader (#3878)

* fix

* unit test
parent e9632568
......@@ -21,7 +21,7 @@ from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd
from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
create_shared_mem_array, get_shared_mem_array, context_of)
create_shared_mem_array, get_shared_mem_array, context_of, dtype_of)
from ..frame import LazyFeature
from ..storages import wrap_storage
from .base import BlockSampler, as_edge_prediction_sampler
......@@ -86,9 +86,11 @@ class _TensorizedDatasetIter(object):
def _get_id_tensor_from_mapping(indices, device, keys):
    """Flatten a per-type ID mapping into a single two-column tensor.

    Parameters
    ----------
    indices : Mapping
        Maps a type key to a tensor of IDs.  Keys listed in ``keys`` that are
        absent from ``indices`` simply contribute zero rows.
        NOTE(review): the ID tensors are assumed to be 1-D and to already live
        on ``device`` -- confirm against callers.
    device : device
        Device on which the type-ID column is created.
    keys : sequence
        Ordered type keys; the position of a key in this sequence is used as
        its integer type ID.

    Returns
    -------
    Tensor
        ``torch.stack([type_ids, all_indices], 1)``: column 0 holds the type
        ID of each entry, column 1 holds the ID itself, grouped in the order
        of ``keys``.
    """
    # Build the type-ID column in the same dtype as the IDs instead of a
    # hard-coded int64, so that both stacked columns share one dtype.
    dtype = dtype_of(indices)
    # Number of IDs per key, in ``keys`` order (0 for keys without entries).
    lengths = torch.tensor(
        [(indices[k].shape[0] if k in indices else 0) for k in keys],
        dtype=dtype, device=device)
    # Repeat each key's position once per ID belonging to that key.
    type_ids = torch.arange(
        len(keys), dtype=dtype, device=device).repeat_interleave(lengths)
    all_indices = torch.cat([indices[k] for k in keys if k in indices])
    return torch.stack([type_ids, all_indices], 1)
......
......@@ -1019,4 +1019,8 @@ def context_of(data):
else:
return F.context(data)
def dtype_of(data):
    """Return the dtype of the data which can be either a tensor or a dict of tensors.

    For a mapping, only the first value in iteration order is inspected; the
    remaining values are not checked for agreement.

    Raises
    ------
    StopIteration
        If ``data`` is an empty mapping, since there is no value to inspect.
    """
    if isinstance(data, Mapping):
        # Use the first value as the representative tensor of the mapping.
        data = next(iter(data.values()))
    return F.dtype(data)
_init_api("dgl.utils.internal")
......@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
from collections import defaultdict
from collections.abc import Iterator, Mapping
from itertools import product
from test_utils import parametrize_dtype
import pytest
......@@ -89,6 +90,15 @@ def test_neighbor_nonuniform(num_workers):
elif seed == 0:
assert neighbors == {3, 4}
def _check_dtype(data, dtype, attr_name):
if isinstance(data, dict):
for k, v in data.items():
assert getattr(v, attr_name) == dtype
elif isinstance(data, list):
for v in data:
assert getattr(v, attr_name) == dtype
else:
assert getattr(data, attr_name) == dtype
def _check_device(data):
if isinstance(data, dict):
......@@ -100,10 +110,11 @@ def _check_device(data):
else:
assert data.device == F.ctx()
@parametrize_dtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('pin_graph', [False, True])
def test_node_dataloader(sampler_name, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
def test_node_dataloader(idtype, sampler_name, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
if F.ctx() != F.cpu() and pin_graph:
g1.create_formats_()
g1.pin_memory_()
......@@ -123,13 +134,16 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
_check_dtype(input_nodes, idtype, 'dtype')
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
})
}).astype(idtype)
for ntype in g2.ntypes:
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
......@@ -146,6 +160,9 @@ def test_node_dataloader(sampler_name, pin_graph):
_check_device(input_nodes)
_check_device(output_nodes)
_check_device(blocks)
_check_dtype(input_nodes, idtype, 'dtype')
_check_dtype(output_nodes, idtype, 'dtype')
_check_dtype(blocks, idtype, 'idtype')
if g1.is_pinned():
g1.unpin_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment