Unverified commit d3ae7544 authored by Xin Yao, committed by GitHub

[Feature] aten::Relabel_() for the GPU (#3445)



* relabel gpu

* unittest for relabel_ on the GPU

* finish Relabel_ for the GPU

* copyright

* re-enable the unittest for edge_subgraph on the GPU

* fix unittest for tensorflow

* use a fixed number of threads
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent f46080a4
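At a high level, the GPU path added below concatenates all input ID arrays, builds an OrderedHashTable on the device that assigns each distinct node ID a consecutive local ID (collecting those distinct IDs as the induced-node array), and then launches one kernel per input array to rewrite every entry in place via a hash-table lookup. Removing the CPU-only CHECK in EdgeSubgraphNoPreserveNodes and the GPU skip decorators in the Python tests turns the corresponding edge_subgraph, in_subgraph, and out_subgraph tests back on for the GPU.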
/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019-2021 by Contributors
* \file array/array.cc
* \brief DGL array utilities implementation
*/
......@@ -205,7 +205,7 @@ NDArray Repeat(NDArray array, IdArray repeats) {
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray ret;
ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, "Relabel_", {
ATEN_XPU_SWITCH_CUDA(arrays[0]->ctx.device_type, XPU, "Relabel_", {
ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
ret = impl::Relabel_<XPU, IdType>(arrays);
});
......
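Switching the dispatch from ATEN_XPU_SWITCH to ATEN_XPU_SWITCH_CUDA is what routes Relabel_ to the CUDA implementation added below in addition to the existing CPU one. As a rough illustration of the kind of device dispatch such a macro performs (DispatchByDevice and its error message are hypothetical stand-ins, not the actual DGL macro, which also binds XPU as a constant for the body):

// Illustrative sketch only; not DGL's macro definition.
#include <stdexcept>

template <typename CpuBody, typename GpuBody>
void DispatchByDevice(int device_type, CpuBody on_cpu, GpuBody on_gpu) {
  switch (device_type) {
    case 1:   // kDLCPU in the DLPack device enum
      on_cpu();
      break;
    case 2:   // kDLGPU; only meaningful when built with CUDA support
      on_gpu();
      break;
    default:
      throw std::runtime_error("Relabel_: unsupported device type");
  }
}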
/*!
* Copyright (c) 2020 by Contributors
* Copyright (c) 2020-2021 by Contributors
* \file array/cuda/array_op_impl.cu
* \brief Array operator GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_hashtable.cuh"
#include "./utils.h"
#include "../arith.h"
namespace dgl {
using runtime::NDArray;
using namespace runtime::cuda;
namespace aten {
namespace impl {
......@@ -258,6 +260,84 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// Relabel_ //////////////////////////////
template <typename IdType>
__global__ void _RelabelKernel(
    IdType* out, int64_t length, DeviceOrderedHashTable<IdType> table) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  // grid-stride loop: each thread rewrites every stride_x-th element,
  // replacing the original node ID with its new consecutive ID from the table
  while (tx < length) {
    out[tx] = table.Search(out[tx])->local;
    tx += stride_x;
  }
}
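The kernel body is a standard grid-stride loop, so the fixed block size chosen below can cover inputs of any length. A small sketch of the launch arithmetic, with an assumed example length:

// Launch arithmetic sketch; `length` is an assumed example value.
const int nt = 128;                      // threads per block (fixed below)
const int64_t length = 1000;             // example input array length
const int nb = (length + nt - 1) / nt;   // 8 blocks -> 1024 threads total
// Thread tx handles elements tx, tx + 1024, tx + 2048, ... while < length;
// with 1024 >= length, each element here is visited exactly once.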
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
  IdArray all_nodes = Concat(arrays);
  const int64_t total_length = all_nodes->shape[0];
  if (total_length == 0) {
    return all_nodes;
  }

  const auto& ctx = arrays[0]->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  // build the node map (old ID -> new consecutive ID) and collect the induced nodes
  OrderedHashTable<IdType> node_map(total_length, ctx, thr_entry->stream);
  int64_t num_induced = 0;
  int64_t* num_induced_device = static_cast<int64_t*>(
      device->AllocWorkspace(ctx, sizeof(int64_t)));
  IdArray induced_nodes = NewIdArray(total_length, ctx, sizeof(IdType)*8);

  CUDA_CALL(cudaMemsetAsync(
      num_induced_device,
      0,
      sizeof(*num_induced_device),
      thr_entry->stream));
  node_map.FillWithDuplicates(
      all_nodes.Ptr<IdType>(),
      all_nodes->shape[0],
      induced_nodes.Ptr<IdType>(),
      num_induced_device,
      thr_entry->stream);

  // copy the number of induced (unique) nodes back to the host and wait for it
  device->CopyDataFromTo(
      num_induced_device, 0,
      &num_induced, 0,
      sizeof(num_induced),
      ctx,
      DGLContext{kDLCPU, 0},
      DGLType{kDLInt, 64, 1},
      thr_entry->stream);
  device->StreamSync(ctx, thr_entry->stream);
  device->FreeWorkspace(ctx, num_induced_device);

  // shrink the induced-node array to the actual number of unique nodes
  induced_nodes->shape[0] = num_induced;

  // relabel: rewrite each input array in place with the new consecutive IDs
  const int nt = 128;
  for (IdArray arr : arrays) {
    const int64_t length = arr->shape[0];
    int nb = (length + nt - 1) / nt;
    CUDA_KERNEL_CALL((_RelabelKernel<IdType>),
        nb, nt, 0, thr_entry->stream,
        arr.Ptr<IdType>(), length, node_map.DeviceHandle());
  }

  return induced_nodes;
}
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
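For reference, a minimal sketch (not part of this file) of what Relabel_ computes, using the same input values as the C++ unit test further down; RelabelSketch and its ctx parameter are illustrative only:

#include <dgl/array.h>
#include <vector>

// Sketch only: Relabel_ remaps the inputs in place and returns the induced nodes.
void RelabelSketch(DLContext ctx) {
  using dgl::IdArray;
  IdArray a = dgl::aten::VecToIdArray(std::vector<int64_t>({0, 20, 10}), 64, ctx);
  IdArray b = dgl::aten::VecToIdArray(std::vector<int64_t>({20, 5, 6}), 64, ctx);
  IdArray induced = dgl::aten::Relabel_({a, b});
  // After the call (labels follow order of first appearance):
  //   a       == [0, 1, 2]           (0 -> 0, 20 -> 1, 10 -> 2)
  //   b       == [1, 3, 4]           (20 -> 1, 5 -> 3, 6 -> 4)
  //   induced == [0, 20, 10, 5, 6]   (original IDs of the new labels 0..4)
}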
///////////////////////////// AsNumBits /////////////////////////////
template <typename InType, typename OutType>
......
......@@ -48,8 +48,6 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) {
// TODO(minjie): In general, all relabeling should be separated with subgraph
// operations.
CHECK(hg->Context().device_type != kDLGPU)
<< "Edge subgraph with relabeling does not support GPU.";
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type.";
HeteroSubgraph ret;
......
......@@ -28,7 +28,6 @@ def generate_graph(grad=False, add_data=True):
g.edata['l'] = ecol
return g
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_edge_subgraph():
# Test when the graph has no node data and edge data.
g = generate_graph(add_data=False)
......@@ -140,8 +139,6 @@ def test_subgraph_mask(idtype):
sg1 = g.subgraph({'user': F.tensor([False, True, True], dtype=F.bool),
'game': F.tensor([True, False, False, False], dtype=F.bool)})
_check_subgraph(g, sg1)
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
sg2 = g.edge_subgraph({'follows': F.tensor([False, True], dtype=F.bool),
'plays': F.tensor([False, True, False, False], dtype=F.bool),
'wishes': F.tensor([False, True], dtype=F.bool)})
......@@ -181,8 +178,6 @@ def test_subgraph1(idtype):
sg1 = g.subgraph({'user': [1, 2], 'game': [0]})
_check_subgraph(g, sg1)
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2)
......@@ -190,8 +185,6 @@ def test_subgraph1(idtype):
sg1 = g.subgraph({'user': F.tensor([1, 2], dtype=idtype),
'game': F.tensor([0], dtype=idtype)})
_check_subgraph(g, sg1)
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
sg2 = g.edge_subgraph({'follows': F.tensor([1], dtype=idtype),
'plays': F.tensor([1], dtype=idtype),
'wishes': F.tensor([1], dtype=idtype)})
......@@ -201,8 +194,6 @@ def test_subgraph1(idtype):
sg1 = g.subgraph({'user': np.array([1, 2]),
'game': np.array([0])})
_check_subgraph(g, sg1)
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
sg2 = g.edge_subgraph({'follows': np.array([1]),
'plays': np.array([1]),
'wishes': np.array([1])})
......@@ -248,8 +239,6 @@ def test_subgraph1(idtype):
sg1_graph = g_graph.subgraph([1, 2])
_check_subgraph_single_ntype(g_graph, sg1_graph)
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
sg1_graph = g_graph.edge_subgraph([1])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg1_graph = g_graph.edge_subgraph([1], relabel_nodes=False)
......@@ -297,8 +286,6 @@ def test_subgraph1(idtype):
_check_typed_subgraph1(g, sg5)
# Test for restricted format
if F._default_context_str != 'gpu':
# TODO(minjie): enable this later
for fmt in ['csr', 'csc', 'coo']:
g = dgl.graph(([0, 1], [1, 2])).formats(fmt)
sg = g.subgraph({g.ntypes[0]: [1, 0]})
......@@ -309,8 +296,6 @@ def test_subgraph1(idtype):
dst = F.asnumpy(dst)
assert np.array_equal(src, np.array([1]))
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
@parametrize_dtype
def test_in_subgraph(idtype):
hg = dgl.heterograph({
......@@ -318,7 +303,7 @@ def test_in_subgraph(idtype):
('user', 'play', 'game'): ([0, 0, 1, 3], [0, 1, 2, 2]),
('game', 'liked-by', 'user'): ([2, 2, 2, 1, 1, 0], [0, 1, 2, 0, 3, 0]),
('user', 'flips', 'coin'): ([0, 1, 2, 3], [0, 0, 0, 0])
}, idtype=idtype, num_nodes_dict={'user': 5, 'game': 10, 'coin': 8})
}, idtype=idtype, num_nodes_dict={'user': 5, 'game': 10, 'coin': 8}).to(F.ctx())
subg = dgl.in_subgraph(hg, {'user' : [0,1], 'game' : 0})
assert subg.idtype == idtype
assert len(subg.ntypes) == 3
......@@ -378,7 +363,6 @@ def test_in_subgraph(idtype):
assert subg.num_nodes('coin') == 0
assert subg.num_edges('flips') == 0
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
@parametrize_dtype
def test_out_subgraph(idtype):
hg = dgl.heterograph({
......@@ -386,7 +370,7 @@ def test_out_subgraph(idtype):
('user', 'play', 'game'): ([0, 0, 1, 3], [0, 1, 2, 2]),
('game', 'liked-by', 'user'): ([2, 2, 2, 1, 1, 0], [0, 1, 2, 0, 3, 0]),
('user', 'flips', 'coin'): ([0, 1, 2, 3], [0, 0, 0, 0])
}, idtype=idtype)
}, idtype=idtype).to(F.ctx())
subg = dgl.out_subgraph(hg, {'user' : [0,1], 'game' : 0})
assert subg.idtype == idtype
assert len(subg.ntypes) == 3
......
......@@ -242,14 +242,14 @@ TEST(ArrayTest, TestIndexSelect) {
}
template <typename IDX>
void _TestRelabel_() {
IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, CTX);
void _TestRelabel_(DLContext ctx) {
IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, ctx);
IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, ctx);
IdArray c = aten::Relabel_({a, b});
IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, CTX);
IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, CTX);
IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, ctx);
IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(a, ta));
ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
......@@ -257,8 +257,12 @@ void _TestRelabel_() {
}
TEST(ArrayTest, TestRelabel_) {
_TestRelabel_<int32_t>();
_TestRelabel_<int64_t>();
_TestRelabel_<int32_t>(CPU);
_TestRelabel_<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestRelabel_<int32_t>(GPU);
_TestRelabel_<int64_t>(GPU);
#endif
}
template <typename IDX>
......