cuda_metadata.cpp 3.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_metadata.hpp>

namespace LightGBM {

// Creates an empty container: no device buffer is allocated until Init() or
// one of the Set*() methods uploads data.
CUDAMetadata::CUDAMetadata(const int gpu_device_id) {
  cuda_init_score_ = nullptr;
  cuda_query_weights_ = nullptr;
  cuda_query_boundaries_ = nullptr;
  cuda_weights_ = nullptr;
  cuda_label_ = nullptr;
  // Negative device ids are mapped to device 0.
  const int device_to_use = (gpu_device_id < 0) ? 0 : gpu_device_id;
  SetCUDADevice(device_to_use, __FILE__, __LINE__);
}

// Releases every device buffer this object may hold. Pointers that were never
// allocated (left as nullptr by the constructor or by Init() for empty host
// vectors) are passed to DeallocateCUDAMemory as well.
CUDAMetadata::~CUDAMetadata() {
  DeallocateCUDAMemory<double>(&cuda_init_score_, __FILE__, __LINE__);
  DeallocateCUDAMemory<label_t>(&cuda_query_weights_, __FILE__, __LINE__);
  DeallocateCUDAMemory<data_size_t>(&cuda_query_boundaries_, __FILE__, __LINE__);
  DeallocateCUDAMemory<label_t>(&cuda_weights_, __FILE__, __LINE__);
  DeallocateCUDAMemory<label_t>(&cuda_label_, __FILE__, __LINE__);
}

void CUDAMetadata::Init(const std::vector<label_t>& label,
                        const std::vector<label_t>& weight,
                        const std::vector<data_size_t>& query_boundaries,
                        const std::vector<label_t>& query_weights,
                        const std::vector<double>& init_score) {
  if (label.size() == 0) {
    cuda_label_ = nullptr;
  } else {
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_label_, label.data(), label.size(), __FILE__, __LINE__);
  }
  if (weight.size() == 0) {
    cuda_weights_ = nullptr;
  } else {
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_weights_, weight.data(), weight.size(), __FILE__, __LINE__);
  }
  if (query_boundaries.size() == 0) {
    cuda_query_boundaries_ = nullptr;
  } else {
    InitCUDAMemoryFromHostMemory<data_size_t>(&cuda_query_boundaries_, query_boundaries.data(), query_boundaries.size(), __FILE__, __LINE__);
  }
  if (query_weights.size() == 0) {
    cuda_query_weights_ = nullptr;
  } else {
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_query_weights_, query_weights.data(), query_weights.size(), __FILE__, __LINE__);
  }
  if (init_score.size() == 0) {
    cuda_init_score_ = nullptr;
  } else {
    InitCUDAMemoryFromHostMemory<double>(&cuda_init_score_, init_score.data(), init_score.size(), __FILE__, __LINE__);
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDAMetadata::SetLabel(const label_t* label, data_size_t len) {
  DeallocateCUDAMemory<label_t>(&cuda_label_, __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<label_t>(&cuda_label_, label, static_cast<size_t>(len), __FILE__, __LINE__);
}

void CUDAMetadata::SetWeights(const label_t* weights, data_size_t len) {
  DeallocateCUDAMemory<label_t>(&cuda_weights_, __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<label_t>(&cuda_weights_, weights, static_cast<size_t>(len), __FILE__, __LINE__);
}

void CUDAMetadata::SetQuery(const data_size_t* query_boundaries, const label_t* query_weights, data_size_t num_queries) {
  DeallocateCUDAMemory<data_size_t>(&cuda_query_boundaries_, __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<data_size_t>(&cuda_query_boundaries_, query_boundaries, static_cast<size_t>(num_queries) + 1, __FILE__, __LINE__);
  if (query_weights != nullptr) {
    DeallocateCUDAMemory<label_t>(&cuda_query_weights_, __FILE__, __LINE__);
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_query_weights_, query_weights, static_cast<size_t>(num_queries), __FILE__, __LINE__);
  }
}

void CUDAMetadata::SetInitScore(const double* init_score, data_size_t len) {
  DeallocateCUDAMemory<double>(&cuda_init_score_, __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<double>(&cuda_init_score_, init_score, static_cast<size_t>(len), __FILE__, __LINE__);
}

}  // namespace LightGBM

#endif  // USE_CUDA_EXP