Unverified Commit 3f138eba authored by Quan (Andy) Gan, committed by GitHub

[Bugfix] Bug fixes in new dataloader (#3727)



* fixes

* fix

* more fixes

* update

* oops

* lint?

* temporarily revert - will fix in another PR

* more fixes

* skipping mxnet test

* address comments

* fix DDP

* fix edge dataloader exclusion problems

* stupid bug

* fix

* use_uvm option

* fix

* fixes

* fixes

* fixes

* fixes

* add evaluation for cluster gcn and ddp

* stupid bug again

* fixes

* move sanity checks to only support DGLGraphs

* pytorch lightning compatibility fixes

* remove

* poke

* more fixes

* fix

* fix

* disable test

* docstrings

* why is it getting a memory leak?

* fix

* update

* updates and temporarily disable forkingpickler

* update

* fix?

* fix?

* oops

* oops

* fix

* lint

* huh

* uh

* update

* fix

* made it memory efficient

* refine exclude interface

* fix tutorial

* fix tutorial

* fix graph duplication in CPU dataloader workers

* lint

* lint

* Revert "lint"

This reverts commit 805484dd553695111b5fb37f2125214a6b7276e9.

* Revert "lint"

This reverts commit 0bce411b2b415c2ab770343949404498436dc8b2.

* Revert "fix graph duplication in CPU dataloader workers"

This reverts commit 9e3a8cf34c175d3093c773f6bb023b155f2bd27f.
Co-authored-by: xiny <xiny@nvidia.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7b9afbfa
......@@ -359,6 +359,18 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::CSRMatrix();
}
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
adj_ = coo;
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
LOG(FATAL) << "Not enabled for COO graph";
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Not enabled for COO graph";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::kCOO;
......@@ -779,6 +791,18 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return adj_;
}
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
LOG(FATAL) << "Not enabled for CSR graph";
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
adj_ = csr;
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Please use in_csr_->SetCSRMatrix(etype, csc) instead.";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::kCSR;
......@@ -1243,7 +1267,7 @@ HeteroGraphPtr UnitGraph::CreateFromCSC(
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));
CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
}
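
For context: the one-line fix above swaps num_src and num_dst because a CSC indptr is indexed by destination nodes, not source nodes. A minimal Python-level sketch of that invariant, assuming the public adj_sparse API that the tests in this file also exercise (graph contents are arbitrary):

import torch
import dgl

num_src, num_dst = 20, 30
src = torch.tensor([0, 1, 2, 19])
dst = torch.tensor([5, 7, 7, 29])

# Build a small bipartite graph, then read back its CSC representation.
g = dgl.heterograph({('A', 'AB', 'B'): (src, dst)},
                    num_nodes_dict={'A': num_src, 'B': num_dst})
indptr, indices, eids = g.adj_sparse('csc')

# The CSC indptr has one entry per destination node plus one, not one per
# source node -- the invariant the constructor above now respects.
assert indptr.shape[0] == num_dst + 1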
......@@ -1488,6 +1512,54 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
return GetCOO()->adj();
}
void UnitGraph::SetCOOMatrix(dgl_type_t etype, COOMatrix coo) {
if (!(formats_ & COO_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set COO matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set COOMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!coo_->defined())
*(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), coo);
else
coo_->SetCOOMatrix(0, coo);
}
void UnitGraph::SetCSRMatrix(dgl_type_t etype, CSRMatrix csr) {
if (!(formats_ & CSR_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSR matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSRMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!out_csr_->defined())
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), csr);
else
out_csr_->SetCSRMatrix(0, csr);
}
void UnitGraph::SetCSCMatrix(dgl_type_t etype, CSRMatrix csc) {
if (!(formats_ & CSC_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSC matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSCMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!in_csr_->defined())
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), csc);
else
in_csr_->SetCSRMatrix(0, csc);
}
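
These setters refuse to install a matrix when the graph's allowed sparse formats exclude it, or when the graph is pinned. The user-visible side of the format restriction is the formats() API; a small hedged sketch (graph contents are arbitrary, printed values are indicative):

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
print(g.formats())       # e.g. {'created': ['coo'], 'not created': ['csr', 'csc']}

# Materialize every format, then ask for a clone restricted to CSR only.
g.create_formats_()
g_csr = g.formats('csr')
print(g_csr.formats())   # e.g. {'created': ['csr'], 'not created': []}

# Anything that needs a format outside the restriction on g_csr fails, which is
# what the "restricted sparse format" checks in SetCOOMatrix/SetCSCMatrix guard.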
HeteroGraphPtr UnitGraph::GetAny() const {
if (in_csr_->defined()) {
return in_csr_;
......
......@@ -214,7 +214,7 @@ class UnitGraph : public BaseHeteroGraph {
* \brief Pin the in_csr_, out_csr_ and coo_ of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDLCPUPinned: directly return;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
......@@ -223,7 +223,7 @@ class UnitGraph : public BaseHeteroGraph {
/*!
* \brief Unpin the in_csr_, out_csr_ and coo_ of the current graph.
* \note The graph will be unpinned inplace. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned;
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
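
The notes above describe in-place pinning semantics; at the Python level this corresponds to pin_memory_ / unpin_memory_ / is_pinned, which the updated dataloader tests below also call. A small sketch, assuming a CUDA-enabled build (hence the guard):

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))

if torch.cuda.is_available():
    g.pin_memory_()       # pins the graph's arrays in place (page-locked host memory)
    assert g.is_pinned()
    g.pin_memory_()       # already pinned: returns directly, as the note above says
    g.unpin_memory_()     # unpins in place
    assert not g.is_pinned()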
......@@ -305,6 +305,10 @@ class UnitGraph : public BaseHeteroGraph {
void InvalidateCOO();
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override;
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override;
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override;
private:
friend class Serializer;
friend class HeteroGraph;
......
......@@ -187,6 +187,42 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaHostUnregister(ptr));
}
bool IsPinned(const void* ptr) override {
// can't be a pinned tensor if CUDA context is unavailable.
if (!is_available_)
return false;
cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
bool result = false;
switch (status) {
case cudaErrorInvalidValue:
// might be a normal CPU tensor in CUDA 10.2-
cudaGetLastError(); // clear error
break;
case cudaSuccess:
result = (attr.type == cudaMemoryTypeHost);
break;
case cudaErrorInitializationError:
case cudaErrorNoDevice:
case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function can be called
// when users only want to run on CPU even if CUDA API is enabled, or in a forked
// subprocess where CUDA context cannot be initialized. So we just mark the CUDA
context as unavailable and return.
is_available_ = false;
cudaGetLastError(); // clear error
break;
default:
LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status);
break;
}
return result;
}
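
IsPinned probes the pointer with cudaPointerGetAttributes and downgrades initialization / no-device errors to "CUDA unavailable" instead of failing, so CPU-only runs and forked workers keep working. A rough Python-level analogue of that defensive pattern, sketched with PyTorch rather than DGL's actual code path:

import torch

def tensor_is_pinned(t: torch.Tensor) -> bool:
    # Report "not pinned" instead of raising when no usable CUDA context exists
    # (CPU-only machine, or a subprocess where CUDA cannot be initialized).
    if not torch.cuda.is_available():
        return False
    return t.device.type == 'cpu' and t.is_pinned()

x = torch.randn(4)
print(tensor_is_pinned(x))                   # False: plain pageable CPU memory
if torch.cuda.is_available():
    print(tensor_is_pinned(x.pin_memory()))  # True: page-locked host memory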
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
......@@ -213,6 +249,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamSynchronize(stream));
}
}
bool is_available_ = true;
};
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
......
......@@ -64,7 +64,7 @@ struct NDArray::Internal {
ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it.
if (ptr->dl_tensor.ctx.device_type == kDLCPUPinned) {
if (IsDataPinned(&(ptr->dl_tensor))) {
UnpinData(&(ptr->dl_tensor));
}
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
......@@ -206,19 +206,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
return ret;
}
inline DLContext GetDevice(DLContext ctx) {
switch (ctx.device_type) {
case kDLCPU:
case kDLGPU:
return ctx;
break;
default:
// fallback to CPU
return DLContext{kDLCPU, 0};
break;
}
}
NDArray NDArray::Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
......@@ -226,7 +213,7 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
if (td->IsAvailable())
return td->Empty(shape, dtype, ctx);
NDArray ret = Internal::Create(shape, dtype, GetDevice(ctx));
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
......@@ -242,6 +229,7 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
data->deleter = Internal::DLPackDeleter;
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
return NDArray(data);
}
......@@ -260,7 +248,7 @@ void NDArray::CopyFromTo(DLTensor* from,
// Use the context that is *not* a cpu context to get the correct device
// api manager.
DGLContext ctx = GetDevice(from->ctx).device_type != kDLCPU ? from->ctx : to->ctx;
DGLContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx;
DeviceAPI::Get(ctx)->CopyDataFromTo(
from->data, static_cast<size_t>(from->byte_offset),
......@@ -269,19 +257,15 @@ void NDArray::CopyFromTo(DLTensor* from,
}
void NDArray::PinData(DLTensor* tensor) {
// Only need to call PinData once, since the pinned memory can be seen
// by all CUDA contexts, not just the one that performed the allocation
if (tensor->ctx.device_type == kDLCPUPinned) return;
if (IsDataPinned(tensor)) return;
CHECK_EQ(tensor->ctx.device_type, kDLCPU)
<< "Only NDArray on CPU can be pinned";
DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor));
tensor->ctx = DLContext{kDLCPUPinned, 0};
}
void NDArray::UnpinData(DLTensor* tensor) {
if (tensor->ctx.device_type != kDLCPUPinned) return;
if (!IsDataPinned(tensor)) return;
DeviceAPI::Get(kDLGPU)->UnpinData(tensor->data);
tensor->ctx = DLContext{kDLCPU, 0};
}
template<typename T>
......@@ -343,6 +327,14 @@ std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
return this->data_->mem;
}
bool NDArray::IsDataPinned(DLTensor* tensor) {
// Can only be pinned if on CPU...
if (tensor->ctx.device_type != kDLCPU)
return false;
// ... and CUDA device API is enabled, and the tensor is indeed in pinned memory.
auto device = DeviceAPI::Get(kDLGPU, true);
return device && device->IsPinned(tensor->data);
}
void NDArray::Save(dmlc::Stream* strm) const {
auto zc_strm = dynamic_cast<StreamWithBuffer*>(strm);
......@@ -489,10 +481,9 @@ int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DLTensor* nd = &(nd_container->dl_tensor);
if ((alignment != 0 && !is_aligned(nd->data, alignment))
|| (nd->ctx.device_type == kDLCPUPinned)) {
if (alignment != 0 && !is_aligned(nd->data, alignment)) {
std::vector<int64_t> shape_vec(nd->shape, nd->shape + nd->ndim);
NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, GetDevice(nd->ctx));
NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx);
copy_ndarray.CopyFrom(nd);
*out = copy_ndarray.ToDLPack();
} else {
......
......@@ -12,6 +12,7 @@ import test_utils
from test_utils import parametrize_dtype, get_cases
from utils import assert_is_identical_hetero
from scipy.sparse import rand
import multiprocessing as mp
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
......@@ -206,6 +207,32 @@ def test_create(idtype):
assert g.device == F.cpu()
assert F.array_equal(g.edata['w'], F.copy_to(F.tensor(adj.data), F.cpu()))
def test_create2():
mat = ssp.random(20, 30, 0.1)
# coo
mat = mat.tocoo()
row = F.tensor(mat.row, dtype=F.int64)
col = F.tensor(mat.col, dtype=F.int64)
g = dgl.heterograph(
{('A', 'AB', 'B'): ('coo', (row, col))}, num_nodes_dict={'A': 20, 'B': 30})
# csr
mat = mat.tocsr()
indptr = F.tensor(mat.indptr, dtype=F.int64)
indices = F.tensor(mat.indices, dtype=F.int64)
data = F.tensor([], dtype=F.int64)
g = dgl.heterograph(
{('A', 'AB', 'B'): ('csr', (indptr, indices, data))}, num_nodes_dict={'A': 20, 'B': 30})
# csc
mat = mat.tocsc()
indptr = F.tensor(mat.indptr, dtype=F.int64)
indices = F.tensor(mat.indices, dtype=F.int64)
data = F.tensor([], dtype=F.int64)
g = dgl.heterograph(
{('A', 'AB', 'B'): ('csc', (indptr, indices, data))}, num_nodes_dict={'A': 20, 'B': 30})
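
test_create2 above only checks that construction from coo/csr/csc tuples succeeds. A hedged follow-up one could append to it (not part of the committed test) would also verify the counts:

# Hypothetical extra assertions for test_create2; mat is the 20x30 scipy matrix above.
assert g.num_nodes('A') == 20 and g.num_nodes('B') == 30
assert g.num_edges('AB') == mat.nnz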
@parametrize_dtype
def test_query(idtype):
g = create_test_heterograph(idtype)
......@@ -2796,6 +2823,24 @@ def test_adj_sparse(idtype, fmt):
assert np.array_equal(F.asnumpy(indices_sorted), indices_sorted_np)
def _test_forking_pickler_entry(g, q):
q.put(g.formats())
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="MXNet doesn't support spawning")
def test_forking_pickler():
ctx = mp.get_context('spawn')
g = dgl.graph(([0,1,2],[1,2,3]))
g.create_formats_()
q = ctx.Queue(1)
proc = ctx.Process(target=_test_forking_pickler_entry, args=(g, q))
proc.start()
fmt = q.get()['created']
proc.join()
assert 'coo' in fmt
assert 'csr' in fmt
assert 'csc' in fmt
if __name__ == '__main__':
# test_create()
# test_query()
......
import os
import numpy as np
import dgl
import dgl.ops as OPS
import backend as F
import unittest
import torch
from functools import partial
from torch.utils.data import DataLoader
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Iterator, Mapping
from itertools import product
import pytest
......@@ -99,7 +101,8 @@ def _check_device(data):
assert data.device == F.ctx()
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('pin_graph', [True, False])
# TODO(BarclayII): Re-enable pin_graph = True after PyTorch is upgraded to 1.9.0 on CI
@pytest.mark.parametrize('pin_graph', [False])
def test_node_dataloader(sampler_name, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
if F.ctx() != F.cpu() and pin_graph:
......@@ -153,7 +156,8 @@ def test_node_dataloader(sampler_name, pin_graph):
dgl.dataloading.negative_sampler.Uniform(2),
dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)])
@pytest.mark.parametrize('pin_graph', [True, False])
# TODO(BarclayII): Re-enable pin_graph = True after PyTorch is upgraded to 1.9.0 on CI
@pytest.mark.parametrize('pin_graph', [False])
def test_edge_dataloader(sampler_name, neg_sampler, pin_graph):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
if F.ctx() != F.cpu() and pin_graph:
......@@ -222,15 +226,99 @@ def test_edge_dataloader(sampler_name, neg_sampler, pin_graph):
if g1.is_pinned():
g1.unpin_memory_()
def _create_homogeneous():
s = torch.randint(0, 200, (1000,), device=F.ctx())
d = torch.randint(0, 200, (1000,), device=F.ctx())
src = torch.cat([s, d])
dst = torch.cat([d, s])
g = dgl.graph((src, dst), num_nodes=200)  # include both directions so eid i and eid i + 1000 are reverses
reverse_eids = torch.cat([torch.arange(1000, 2000), torch.arange(0, 1000)]).to(F.ctx())
always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
seed_edges = torch.arange(0, 1000, device=F.ctx())
return g, reverse_eids, always_exclude, seed_edges
def _create_heterogeneous():
edges = {}
for utype, etype, vtype in [('A', 'AA', 'A'), ('A', 'AB', 'B')]:
s = torch.randint(0, 200, (1000,), device=F.ctx())
d = torch.randint(0, 200, (1000,), device=F.ctx())
edges[utype, etype, vtype] = (s, d)
edges[vtype, 'rev-' + etype, utype] = (d, s)
g = dgl.heterograph(edges, num_nodes_dict={'A': 200, 'B': 200})
reverse_etypes = {'AA': 'rev-AA', 'AB': 'rev-AB', 'rev-AA': 'AA', 'rev-AB': 'AB'}
always_exclude = {
'AA': torch.randint(0, 1000, (50,), device=F.ctx()),
'AB': torch.randint(0, 1000, (50,), device=F.ctx())}
seed_edges = {
'AA': torch.arange(0, 1000, device=F.ctx()),
'AB': torch.arange(0, 1000, device=F.ctx())}
return g, reverse_etypes, always_exclude, seed_edges
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
if exclude is None:
return always_exclude
elif exclude == 'self':
return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
elif exclude == 'reverse_id':
pair_eids = torch.cat([pair_eids, pair_eids + 1000])
return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
elif exclude == 'reverse_types':
pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
if ('A', 'AA', 'A') in pair_eids:
pair_eids[('A', 'rev-AA', 'A')] = pair_eids[('A', 'AA', 'A')]
if ('A', 'AB', 'B') in pair_eids:
pair_eids[('B', 'rev-AB', 'A')] = pair_eids[('A', 'AB', 'B')]
if always_exclude is not None:
always_exclude = {g.to_canonical_etype(k): v for k, v in always_exclude.items()}
for k in always_exclude.keys():
if k in pair_eids:
pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
else:
pair_eids[k] = always_exclude[k]
return pair_eids
@pytest.mark.parametrize('always_exclude_flag', [False, True])
@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types'])
def test_edge_dataloader_excludes(exclude, always_exclude_flag):
if exclude == 'reverse_types':
g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
else:
g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
g = g.to(F.ctx())
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
if not always_exclude_flag:
always_exclude = None
kwargs = {}
kwargs['exclude'] = (
partial(_find_edges_to_exclude, g, exclude, always_exclude) if always_exclude_flag
else exclude)
kwargs['reverse_eids'] = reverse_eids if exclude == 'reverse_id' else None
kwargs['reverse_etypes'] = reverse_etypes if exclude == 'reverse_types' else None
dataloader = dgl.dataloading.EdgeDataLoader(
g, seed_edges, sampler, batch_size=50, device=F.ctx(), **kwargs)
for input_nodes, pair_graph, blocks in dataloader:
block = blocks[0]
pair_eids = pair_graph.edata[dgl.EID]
block_eids = block.edata[dgl.EID]
edges_to_exclude = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids)
if edges_to_exclude is None:
continue
edges_to_exclude = dgl.utils.recursive_apply(edges_to_exclude, lambda x: x.cpu().numpy())
block_eids = dgl.utils.recursive_apply(block_eids, lambda x: x.cpu().numpy())
if isinstance(edges_to_exclude, Mapping):
for k in edges_to_exclude.keys():
assert not np.isin(edges_to_exclude[k], block_eids[k]).any()
else:
assert not np.isin(edges_to_exclude, block_eids).any()
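
The test above drives the refined exclude interface through a callable built with functools.partial. For reference, the plain string-based usage that the same EdgeDataLoader interface supports looks roughly like this (graph and hyperparameters are illustrative):

import torch
import dgl

# Small graph where edge i and edge i + 4 are reverses of each other.
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 0])
g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
reverse_eids = torch.cat([torch.arange(4, 8), torch.arange(0, 4)])

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.EdgeDataLoader(
    g, torch.arange(g.num_edges()), sampler,
    exclude='reverse_id',        # drop each seed edge's reverse from the sampled blocks
    reverse_eids=reverse_eids,
    batch_size=4, shuffle=True, drop_last=False)

for input_nodes, pair_graph, blocks in dataloader:
    # blocks[0] no longer contains the reverses of the seed edges in pair_graph.
    pass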
if __name__ == '__main__':
test_graph_dataloader()
test_cluster_gcn(0)
test_neighbor_nonuniform(0)
for sampler in ['full', 'neighbor']:
test_node_dataloader(sampler)
for neg_sampler in [
dgl.dataloading.negative_sampler.Uniform(2),
dgl.dataloading.negative_sampler.GlobalUniform(2, False),
dgl.dataloading.negative_sampler.GlobalUniform(2, True)]:
for pin_graph in [True, False]:
test_edge_dataloader(sampler, neg_sampler, pin_graph)
for exclude in [None, 'self', 'reverse_id', 'reverse_types']:
test_edge_dataloader_excludes(exclude, False)
test_edge_dataloader_excludes(exclude, True)