Commit 751b4c26 (unverified), authored by nv-dlasalle, committed by GitHub
Browse files

[Bugfix] Replace global cudaStream in Filter with runtime calls (fix #5153) (#5157)



* Add failing unit test

* Add fix

* Remove extra newline

* skip cpu test
Co-authored-by: Xin Yao <yaox12@outlook.com>
parent 84e4d021
......@@ -18,8 +18,6 @@ namespace array {
namespace {
cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
template <typename IdType, bool include>
__global__ void _IsInKernel(
DeviceOrderedHashTable<IdType> table, const IdType* const array,
......@@ -46,6 +44,7 @@ IdArray _PerformFilter(const OrderedHashTable<IdType>& table, IdArray test) {
const auto& ctx = test->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t size = test->shape[0];
cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
if (size == 0) {
return test;
......@@ -108,7 +107,8 @@ template <typename IdType>
class CudaFilterSet : public Filter {
public:
explicit CudaFilterSet(IdArray array)
: table_(array->shape[0], array->ctx, cudaStream) {
: table_(array->shape[0], array->ctx, runtime::getCurrentCUDAStream()) {
cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
table_.FillWithUnique(
static_cast<const IdType*>(array->data), array->shape[0], cudaStream);
}
......
import unittest
import backend as F
import numpy as np
from test_utils import parametrize_idtype
import dgl
import numpy as np
from dgl.utils import Filter
from test_utils import parametrize_idtype
def test_graph_filter():
......@@ -71,6 +71,28 @@ def test_array_filter(idtype):
assert F.array_equal(ye_act, ye_exp)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Multiple streams are only supported by pytorch backend",
)
@unittest.skipIf(
    F._default_context_str == "cpu", reason="CPU not yet supported"
)
@parametrize_idtype
def test_filter_multistream(idtype):
    """Smoke-test ``Filter`` while a non-default CUDA stream is current.

    The test makes no assertion about the returned indices; it only checks
    that building a ``Filter`` and querying it inside a
    ``torch.cuda.stream`` context does not trip any internal assertions.

    Parameters
    ----------
    idtype
        Index dtype supplied by ``parametrize_idtype``; used for both the
        filter contents and the queried array.
    """
    import torch

    num_rounds = 10
    side_stream = torch.cuda.Stream(device=F.ctx())
    with torch.cuda.stream(side_stream):
        # Repeat the build-and-query cycle so the stream is still busy with
        # earlier work when the next round launches.
        for _ in range(num_rounds):
            members = F.arange(1000, 4000, dtype=idtype, ctx=F.ctx())
            id_filter = Filter(members)
            queries = F.randint(
                [30000], dtype=idtype, ctx=F.ctx(), low=0, high=50000
            )
            hits = id_filter.find_included_indices(queries)
if __name__ == "__main__":
    # Script entry point: run the tests in declaration order when this
    # module is executed directly rather than collected by a test runner.
    # NOTE(review): test_array_filter is declared with an `idtype` parameter
    # but is invoked here without arguments -- confirm that
    # `parametrize_idtype` makes a bare call valid.
    for run_test in (test_graph_filter, test_array_filter):
        run_test()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment