"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "210a07b13cb7589005e1e3b6368e5deb0d815b66"
Unverified commit f6349508 authored by nv-dlasalle, committed by GitHub

[Performance][Feature] Implement edge excluding in EdgeDataLoader on GPU (#3226)



* Update filter code

* Add unit tests

* Fixes

* Switch to indices

* Rename functions

* Fix linting

* Fix whitespace

* Add doc

* Fix heterograph

* Change workspace allocation

* Fix linting

* Fix docs in filter.py

* Add todo
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent fc6f0b9e
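For orientation, this change lets EdgeDataLoader exclude seed edges (and their reverses) from sampled frontiers on the GPU instead of falling back to a numpy membership test on the host. A minimal usage sketch, assuming the DGL 0.7-era EdgeDataLoader API; the toy graph and reverse-edge mapping below are illustrative and not taken from this commit:

import torch
import dgl

# Toy bidirected graph: edge i and edge i + 4 are reverses of each other.
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 0])
g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
rev_eids = torch.cat([torch.arange(4, 8), torch.arange(0, 4)])

sampler = dgl.dataloading.MultiLayerNeighborSampler([2, 2])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, torch.arange(g.num_edges()), sampler,
    exclude='reverse_id',        # drop each seed edge and its reverse from the frontier
    reverse_eids=rev_eids,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=2, shuffle=True, num_workers=0)

for input_nodes, pair_graph, blocks in dataloader:
    pass                         # link-prediction training step goes here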
...@@ -46,6 +46,83 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
result = np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]
return F.zerocopy_from_numpy(result)
class _EidExcluder():
def __init__(self, exclude_eids):
device = None
if isinstance(exclude_eids, Mapping):
for _, v in exclude_eids.items():
if device is None:
device = F.context(v)
break
else:
device = F.context(exclude_eids)
self._exclude_eids = None
self._filter = None
if device == F.cpu():
# TODO(nv-dlasalle): Once Filter is implemented for the CPU, we
# should just use that regardless of the device.
self._exclude_eids = (
_tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None)
else:
if isinstance(exclude_eids, Mapping):
self._filter = {k: utils.Filter(v) for k, v in exclude_eids.items()}
else:
self._filter = utils.Filter(exclude_eids)
def _find_indices(self, parent_eids):
""" Find the set of edge indices to remove.
"""
if self._exclude_eids is not None:
parent_eids_np = _tensor_or_dict_to_numpy(parent_eids)
return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
else:
assert self._filter is not None
if isinstance(parent_eids, Mapping):
located_eids = {k: self._filter[k].find_included_indices(parent_eids[k])
for k, v in parent_eids.items()}
else:
located_eids = self._filter.find_included_indices(parent_eids)
return located_eids
def __call__(self, frontier):
parent_eids = frontier.edata[EID]
located_eids = self._find_indices(parent_eids)
if not isinstance(located_eids, Mapping):
# (BarclayII) If frontier already has an EID field and located_eids is empty,
# the returned graph will keep EID intact. Otherwise, EID will change
# to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0:
frontier = transform.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else:
# (BarclayII) remove_edges only accepts removing one type of edges,
# so I need to keep track of the edge IDs left one by one.
new_eids = parent_eids.copy()
for k, v in located_eids.items():
if len(v) > 0:
frontier = transform.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids
return frontier
def _create_eid_excluder(exclude_eids, device):
if exclude_eids is None:
return None
if device is not None:
if isinstance(exclude_eids, Mapping):
exclude_eids = {k: F.copy_to(v, device) \
for k, v in exclude_eids.items()}
else:
exclude_eids = F.copy_to(exclude_eids, device)
return _EidExcluder(exclude_eids)
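# Illustration only (not part of the diff): on the CPU path above, _EidExcluder
# reduces to the same numpy membership test used by _locate_eids_to_exclude;
# the GPU path computes the same index set with dgl.utils.Filter instead.
import numpy as np
frontier_parent_eids = np.array([7, 3, 11, 5, 3])           # toy parent-graph edge IDs
exclude_eids = np.array([3, 11])                            # toy edges to exclude
located = np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]
# located == array([1, 2, 4]): positions of the frontier edges to remove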
def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
if isinstance(eids, Mapping):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
...@@ -230,8 +307,7 @@ class BlockSampler(object):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
blocks = []
exclude_eids = (
_tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None)
eid_excluder = _create_eid_excluder(exclude_eids, self.output_device)
if isinstance(g, DistGraph):
# TODO:(nv-dlasalle) dist graphs may not have an associated graph,
...@@ -250,32 +326,6 @@ class BlockSampler(object):
seed_nodes_in = seed_nodes_in.to(graph_device)
frontier = self.sample_frontier(block_id, g, seed_nodes_in)
# Removing edges from the frontier for link prediction training falls
# into the category of frontier postprocessing
if exclude_eids is not None:
parent_eids = frontier.edata[EID]
parent_eids_np = _tensor_or_dict_to_numpy(parent_eids)
located_eids = _locate_eids_to_exclude(parent_eids_np, exclude_eids)
if not isinstance(located_eids, Mapping):
# (BarclayII) If frontier already has an EID field and located_eids is empty,
# the returned graph will keep EID intact. Otherwise, EID will change
# to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0:
frontier = transform.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else:
# (BarclayII) remove_edges only accepts removing one type of edges,
# so I need to keep track of the edge IDs left one by one.
new_eids = parent_eids.copy()
for k, v in located_eids.items():
if len(v) > 0:
frontier = transform.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids
if self.output_device is not None:
frontier = frontier.to(self.output_device)
if isinstance(seed_nodes, dict):
...@@ -286,6 +336,11 @@ class BlockSampler(object):
else:
seed_nodes_out = seed_nodes
# Removing edges from the frontier for link prediction training falls
# into the category of frontier postprocessing
if eid_excluder is not None:
frontier = eid_excluder(frontier)
block = transform.to_block(frontier, seed_nodes_out)
if self.return_eids:
assign_block_eids(block, frontier)
......
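# Illustration only: the EID remapping that both the removed inline code and
# _EidExcluder.__call__ perform.  After transform.remove_edges(..., store_ids=True),
# the pruned frontier's EID field indexes into the *old* frontier, so gather_row
# (plain indexing below) maps those positions back to parent-graph edge IDs.
import torch
parent_eids = torch.tensor([40, 41, 42, 43, 44])   # toy parent-graph edge IDs of the frontier
kept_positions = torch.tensor([0, 2, 4])           # what edata[EID] holds after remove_edges
remapped = parent_eids[kept_positions]             # tensor([40, 42, 44]): restored parent EIDs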
...@@ -3,3 +3,4 @@ from .internal import *
from .data import *
from .checks import *
from .shared_mem import *
from .filter import *
"""Utilities for finding overlap or missing items in arrays."""
from .._ffi.function import _init_api
from .. import backend as F
class Filter(object):
"""Class used to either find the subset of IDs that are in this
filter, or the subset of IDs that are not in this filter
given a second set of IDs.
Examples
--------
>>> import torch as th
>>> from dgl.utils import Filter
>>> f = Filter(th.tensor([3,2,9], device=th.device('cuda')))
>>> f.find_included_indices(th.tensor([0,2,8,9], device=th.device('cuda')))
tensor([1, 3], device='cuda:0')
>>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda')))
tensor([0, 2], device='cuda:0')
"""
def __init__(self, ids):
"""Create a new filter from a given set of IDs. This currently is only
implemented for the GPU.
Parameters
----------
ids : IdArray
The unique set of IDs to keep in the filter.
"""
self._filter = _CAPI_DGLFilterCreateFromSet(
F.zerocopy_to_dgl_ndarray(ids))
def find_included_indices(self, test):
"""Find the index of the IDs in `test` that are in this filter.
Parameters
----------
test : IdArray
The set of IDs to test with.
Returns
-------
IdArray
The indices of the IDs in `test` that are also in this filter.
"""
return F.zerocopy_from_dgl_ndarray( \
_CAPI_DGLFilterFindIncludedIndices( \
self._filter, F.zerocopy_to_dgl_ndarray(test)))
def find_excluded_indices(self, test):
"""Find the index of the IDs in `test` that are not in this filter.
Parameters
----------
test : IdArray
The set of IDs to test with.
Returns
-------
IdArray
The indices of the IDs in `test` that are not in this filter.
"""
return F.zerocopy_from_dgl_ndarray( \
_CAPI_DGLFilterFindExcludedIndices( \
self._filter, F.zerocopy_to_dgl_ndarray(test)))
_init_api("dgl.utils.filter")
/*!
* Copyright (c) 2021 by Contributors
* \file array/cuda/cuda_filter.cc
* \brief Object for selecting items in a set, or selecting items not in a set.
*/
#include <dgl/runtime/device_api.h>
#include "../filter.h"
#include "../../runtime/cuda/cuda_hashtable.cuh"
#include "./dgl_cub.cuh"
namespace dgl {
namespace array {
using namespace dgl::runtime::cuda;
namespace {
// TODO(nv-dlasalle): Replace with getting the stream from the context
// when it's implemented.
constexpr cudaStream_t cudaDefaultStream = 0;
template<typename IdType, bool include>
__global__ void _IsInKernel(
DeviceOrderedHashTable<IdType> table,
const IdType * const array,
const int64_t size,
IdType * const mark) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) {
mark[idx] = table.Contains(array[idx]) ^ (!include);
}
}
template<typename IdType>
__global__ void _InsertKernel(
const IdType * const prefix,
const int64_t size,
IdType * const result) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) {
if (prefix[idx] != prefix[idx+1]) {
result[prefix[idx]] = idx;
}
}
}
template<typename IdType, bool include>
IdArray _PerformFilter(
const OrderedHashTable<IdType>& table,
IdArray test) {
const auto& ctx = test->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t size = test->shape[0];
if (size == 0) {
return test;
}
cudaStream_t stream = cudaDefaultStream;
// we need two arrays: 1) to act as a prefix sum
// for the number of entries that will be inserted, and
// 2) to collect the included items.
IdType * prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx, sizeof(IdType)*(size+1)));
// will resize down later
IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType)*8);
// mark each index based on its existence in the hashtable
{
const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x);
_IsInKernel<IdType, include><<<grid, block, 0, stream>>>(
table.DeviceHandle(),
static_cast<const IdType*>(test->data),
size,
prefix);
CUDA_CALL(cudaGetLastError());
}
// generate prefix-sum
{
size_t workspace_bytes;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr,
workspace_bytes,
static_cast<IdType*>(nullptr),
static_cast<IdType*>(nullptr),
size+1));
void * workspace = device->AllocWorkspace(ctx, workspace_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
workspace,
workspace_bytes,
prefix,
prefix,
size+1, stream));
device->FreeWorkspace(ctx, workspace);
}
// copy the number of selected items back to the host
IdType num_unique;
device->CopyDataFromTo(prefix+size, 0,
&num_unique, 0,
sizeof(num_unique),
ctx,
DGLContext{kDLCPU, 0},
test->dtype,
stream);
// insert items into set
{
const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x);
_InsertKernel<<<grid, block, 0, stream>>>(
prefix,
size,
static_cast<IdType*>(result->data));
CUDA_CALL(cudaGetLastError());
}
device->FreeWorkspace(ctx, prefix);
return result.CreateView({num_unique}, result->dtype);
}
template<typename IdType>
class CudaFilterSet : public Filter {
public:
explicit CudaFilterSet(IdArray array) :
table_(array->shape[0], array->ctx, cudaDefaultStream) {
table_.FillWithUnique(
static_cast<const IdType*>(array->data),
array->shape[0],
cudaDefaultStream);
}
IdArray find_included_indices(IdArray test) override {
return _PerformFilter<IdType, true>(table_, test);
}
IdArray find_excluded_indices(IdArray test) override {
return _PerformFilter<IdType, false>(table_, test);
}
private:
OrderedHashTable<IdType> table_;
};
} // namespace
template<DLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
}
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
} // namespace array
} // namespace dgl
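# Host-side illustration (numpy) of the compaction _PerformFilter runs on the GPU:
# mark membership per element (_IsInKernel), take an exclusive prefix sum of the
# marks (cub::DeviceScan::ExclusiveSum), then scatter each marked position into
# its output slot (_InsertKernel).  Toy values; not the CUDA implementation.
import numpy as np
filter_ids = np.array([3, 2, 9])
test = np.array([0, 2, 8, 9])
include = True

mark = (np.isin(test, filter_ids) == include).astype(np.int64)
prefix = np.zeros(len(test) + 1, dtype=np.int64)
prefix[1:] = np.cumsum(mark)                        # exclusive prefix sum over the marks
result = np.empty(prefix[-1], dtype=np.int64)
for i in range(len(test)):
    if prefix[i] != prefix[i + 1]:                  # element i was marked
        result[prefix[i]] = i                       # compact its index
# result == array([1, 3]), matching the Filter docstring example for included indices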
/*!
* Copyright (c) 2021 by Contributors
* \file array/filter.cc
* \brief Object for selecting items in a set, or selecting items not in a set.
*/
#include "./filter.h"
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/packed_func_ext.h>
namespace dgl {
namespace array {
using namespace dgl::runtime;
template<DLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set);
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
IdArray array = args[0];
auto ctx = array->ctx;
// TODO(nv-dlasalle): Implement CPU version.
if (ctx.device_type == kDLGPU) {
#ifdef DGL_USE_CUDA
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
*rv = CreateSetFilter<kDLGPU, IdType>(array);
});
#else
LOG(FATAL) << "GPU support not compiled.";
#endif
} else {
LOG(FATAL) << "CPU support not yet implemented.";
}
});
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindIncludedIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
FilterRef filter = args[0];
IdArray array = args[1];
*rv = filter->find_included_indices(array);
});
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterFindExcludedIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
FilterRef filter = args[0];
IdArray array = args[1];
*rv = filter->find_excluded_indices(array);
});
} // namespace array
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file array/filter.h
* \brief Object for selecting items in a set, or selecting items not in a set.
*/
#ifndef DGL_ARRAY_FILTER_H_
#define DGL_ARRAY_FILTER_H_
#include <dgl/runtime/object.h>
#include <dgl/array.h>
namespace dgl {
namespace array {
class Filter : public runtime::Object {
public:
static constexpr const char* _type_key = "array.Filter";
DGL_DECLARE_OBJECT_TYPE_INFO(Filter, Object);
/**
* @brief From the test set of items, get the index of those which are
* included by this filter.
*
* @param test The set of items to check for.
*
* @return The indices of the items from `test` that are selected by
* this filter.
*/
virtual IdArray find_included_indices(
IdArray test) = 0;
/**
* @brief From the test set of items, get the indices of those which are
* excluded by this filter.
*
* @param test The set of items to check for.
*
* @return The indices of the items from `test` that are not selected by this
* filter.
*/
virtual IdArray find_excluded_indices(
IdArray test) = 0;
};
DGL_DEFINE_OBJECT_REF(FilterRef, Filter);
} // namespace array
} // namespace dgl
#endif // DGL_ARRAY_FILTER_H_
...@@ -108,6 +108,28 @@ class DeviceOrderedHashTable {
return &table_[pos];
}
/**
* \brief Check whether a key exists within the hashtable.
*
* \param id The key to check for.
*
* \return True if the key exists in the hashtable.
*/
inline __device__ bool Contains(
const IdType id) const {
IdType pos = Hash(id);
IdType delta = 1;
while (table_[pos].key != kEmptyKey) {
if (table_[pos].key == id) {
return true;
}
pos = Hash(pos+delta);
delta +=1;
}
return false;
}
protected:
// Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
......
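# Illustration only: the probing loop that DeviceOrderedHashTable::Contains uses,
# written out in Python.  The real Hash() in cuda_hashtable.cuh is not shown in
# this hunk; a simple modulo hash over a toy open-addressing table is assumed.
EMPTY = -1                                          # stand-in for kEmptyKey

def contains(table, key):
    pos = key % len(table)                          # stand-in for Hash(id)
    delta = 1
    while table[pos] != EMPTY:
        if table[pos] == key:
            return True
        pos = (pos + delta) % len(table)            # stand-in for Hash(pos + delta)
        delta += 1
    return False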
import dgl
import backend as F
import numpy as np
import unittest
from test_utils import parametrize_dtype
from dgl.utils import Filter
def test_graph_filter():
g = dgl.DGLGraph().to(F.ctx())
g.add_nodes(4)
g.add_edges([0,1,2,3], [1,2,3,0])
...@@ -36,6 +39,28 @@ def test_filter():
e_idx = g.filter_edges(predicate, [0, 1])
assert set(F.zerocopy_to_numpy(e_idx)) == {1}
@unittest.skipIf(F._default_context_str == 'cpu',
reason="CPU not yet supported")
@parametrize_dtype
def test_array_filter(idtype):
f = Filter(F.copy_to(F.tensor([0,1,9,4,6,5,7], dtype=idtype), F.ctx()))
x = F.copy_to(F.tensor([0,3,9,11], dtype=idtype), F.ctx())
y = F.copy_to(F.tensor([0,19,0,28,3,9,11,4,5], dtype=idtype), F.ctx())
xi_act = f.find_included_indices(x)
xi_exp = F.copy_to(F.tensor([0,2], dtype=idtype), F.ctx())
assert F.array_equal(xi_act, xi_exp)
xe_act = f.find_excluded_indices(x)
xe_exp = F.copy_to(F.tensor([1,3], dtype=idtype), F.ctx())
assert F.array_equal(xe_act, xe_exp)
yi_act = f.find_included_indices(y)
yi_exp = F.copy_to(F.tensor([0,2,5,7,8], dtype=idtype), F.ctx())
assert F.array_equal(yi_act, yi_exp)
ye_act = f.find_excluded_indices(y)
ye_exp = F.copy_to(F.tensor([1,3,4,6], dtype=idtype), F.ctx())
assert F.array_equal(ye_act, ye_exp)
if __name__ == '__main__':
test_graph_filter()
test_array_filter()