cuda_utils.cpp 809 Bytes
Newer Older
1
2
3
4
5
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

6
#ifdef USE_CUDA
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

#include <LightGBM/cuda/cuda_utils.h>

namespace LightGBM {

void SynchronizeCUDADevice(const char* file, const int line) {
  gpuAssert(cudaDeviceSynchronize(), file, line);
}

void PrintLastCUDAError() {
  const char* error_name = cudaGetErrorName(cudaGetLastError());
  Log::Fatal(error_name);
}

void SetCUDADevice(int gpu_device_id, const char* file, int line) {
  int cur_gpu_device_id = 0;
  CUDASUCCESS_OR_FATAL_OUTER(cudaGetDevice(&cur_gpu_device_id));
  if (cur_gpu_device_id != gpu_device_id) {
    CUDASUCCESS_OR_FATAL_OUTER(cudaSetDevice(gpu_device_id));
  }
}

}  // namespace LightGBM

31
#endif  // USE_CUDA