/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/utils.cu
 * @brief Utilities for CUDA kernels.
 */

#include <cub/cub.cuh>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {
namespace cuda {

namespace {

/**
 * @brief Owns a workspace buffer obtained from a DeviceAPI and returns it to
 * the allocator when the guard goes out of scope.
 *
 * Holding buffers in this guard releases them on every exit path, including
 * when a CUDA_CALL check fails (presumably by throwing — see cuda_common.h),
 * instead of leaking them.
 */
class ScopedWorkspace {
 public:
  ScopedWorkspace(
      runtime::DeviceAPI* device, const DGLContext& ctx, size_t nbytes)
      : device_(device), ctx_(ctx), ptr_(device->AllocWorkspace(ctx, nbytes)) {}
  ~ScopedWorkspace() { device_->FreeWorkspace(ctx_, ptr_); }

  ScopedWorkspace(const ScopedWorkspace&) = delete;
  ScopedWorkspace& operator=(const ScopedWorkspace&) = delete;

  /** @brief Raw pointer to the owned workspace. */
  void* get() const { return ptr_; }

 private:
  runtime::DeviceAPI* device_;
  DGLContext ctx_;
  void* ptr_;
};

// Largest element count handed to a single CUB call. CUB 1.x declares
// DeviceReduce's `num_items` as `int`, so passing a larger int64_t length
// would be narrowed silently and only part of the array would be reduced.
constexpr int64_t kMaxItemsPerReduce = 0x7fffffff;

/**
 * @brief Compute min(flags[0, num_items)) with CUB on `stream` and copy the
 * result back to the host.
 *
 * @param device Device API used for workspace allocation and the scalar copy.
 * @param ctx Device context owning `flags` and `rst`.
 * @param flags Device pointer to the values to reduce.
 * @param num_items Number of values to reduce; must be > 0.
 * @param rst Device scratch slot (one int8_t) that receives the minimum.
 * @param stream Stream the reduction is enqueued on.
 * @return The minimum value, read back to the host via GetCUDAScalar.
 */
int8_t ReduceMin(
    runtime::DeviceAPI* device, const DGLContext& ctx, int8_t* flags,
    int num_items, int8_t* rst, cudaStream_t stream) {
  // First call with a null workspace only queries the required size.
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceReduce::Min(
      nullptr, workspace_size, flags, rst, num_items, stream));
  ScopedWorkspace workspace(device, ctx, workspace_size);
  CUDA_CALL(cub::DeviceReduce::Min(
      workspace.get(), workspace_size, flags, rst, num_items, stream));
  return GetCUDAScalar(device, ctx, rst);
}

}  // namespace

/**
 * @brief Check whether every entry of a device-resident int8 flag array is 1.
 *
 * The result is `min(flags) == 1`. For 0/1 flags, this means "all flags are
 * true". An empty array (length <= 0) is vacuously all-true.
 *
 * @param flags Device pointer to `length` int8 flags.
 * @param length Number of flags. Lengths above INT_MAX are processed in
 *        chunks so CUB's `int` item count is never overflowed.
 * @param ctx Device context that owns `flags`.
 * @return True if the minimum flag value is exactly 1.
 */
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {
  // Nothing to check: skip the kernel launch and the blocking read-back.
  if (length <= 0) return true;

  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  ScopedWorkspace rst_buf(device, ctx, sizeof(int8_t));
  int8_t* rst = static_cast<int8_t*>(rst_buf.get());

  int8_t min_flag = 0;
  for (int64_t offset = 0; offset < length;) {
    const int64_t remaining = length - offset;
    const int num_items = static_cast<int>(
        remaining < kMaxItemsPerReduce ? remaining : kMaxItemsPerReduce);
    const int8_t chunk_min =
        ReduceMin(device, ctx, flags + offset, num_items, rst, stream);
    if (offset == 0 || chunk_min < min_flag) min_flag = chunk_min;
    // The running minimum can only decrease. Once it is below 1, the final
    // result is already known to be false.
    if (min_flag < 1) return false;
    offset += num_items;
  }
  return min_flag == 1;
}

}  // namespace cuda
}  // namespace dgl