cuda_row_data.cpp 22.3 KB
Newer Older
1
2
3
4
5
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

6
#ifdef USE_CUDA
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

#include <LightGBM/cuda/cuda_row_data.hpp>

namespace LightGBM {

CUDARowData::CUDARowData(const Dataset* train_data,
                         const TrainingShareStates* train_share_state,
                         const int gpu_device_id,
                         const bool gpu_use_dp):
gpu_device_id_(gpu_device_id),
gpu_use_dp_(gpu_use_dp) {
  num_threads_ = OMP_NUM_THREADS();
  num_data_ = train_data->num_data();
  const auto& feature_hist_offsets = train_share_state->feature_hist_offsets();
  if (gpu_use_dp_) {
    shared_hist_size_ = DP_SHARED_HIST_SIZE;
  } else {
    shared_hist_size_ = SP_SHARED_HIST_SIZE;
  }
  if (feature_hist_offsets.empty()) {
    num_total_bin_ = 0;
  } else {
    num_total_bin_ = static_cast<int>(feature_hist_offsets.back());
  }
  num_feature_group_ = train_data->num_feature_groups();
  num_feature_ = train_data->num_features();
  if (gpu_device_id >= 0) {
    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
  } else {
    SetCUDADevice(0, __FILE__, __LINE__);
  }
  cuda_data_uint8_t_ = nullptr;
  cuda_data_uint16_t_ = nullptr;
  cuda_data_uint32_t_ = nullptr;
  cuda_row_ptr_uint16_t_ = nullptr;
  cuda_row_ptr_uint32_t_ = nullptr;
  cuda_row_ptr_uint64_t_ = nullptr;
  cuda_partition_ptr_uint16_t_ = nullptr;
  cuda_partition_ptr_uint32_t_ = nullptr;
  cuda_partition_ptr_uint64_t_ = nullptr;
  cuda_feature_partition_column_index_offsets_ = nullptr;
  cuda_column_hist_offsets_ = nullptr;
  cuda_partition_hist_offsets_ = nullptr;
  cuda_block_buffer_uint16_t_ = nullptr;
  cuda_block_buffer_uint32_t_ = nullptr;
  cuda_block_buffer_uint64_t_ = nullptr;
}

CUDARowData::~CUDARowData() {
  DeallocateCUDAMemory<uint8_t>(&cuda_data_uint8_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_data_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_data_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_row_ptr_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_row_ptr_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint64_t>(&cuda_row_ptr_uint64_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_feature_partition_column_index_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_column_hist_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_partition_hist_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_block_buffer_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_block_buffer_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint64_t>(&cuda_block_buffer_uint64_t_, __FILE__, __LINE__);
}

void CUDARowData::Init(const Dataset* train_data, TrainingShareStates* train_share_state) {
  if (num_feature_ == 0) {
    return;
  }
  DivideCUDAFeatureGroups(train_data, train_share_state);
  bit_type_ = 0;
  size_t total_size = 0;
  const void* host_row_ptr = nullptr;
  row_ptr_bit_type_ = 0;
  const void* host_data = train_share_state->GetRowWiseData(&bit_type_, &total_size, &is_sparse_, &host_row_ptr, &row_ptr_bit_type_);
  if (bit_type_ == 8) {
    if (!is_sparse_) {
      std::vector<uint8_t> partitioned_data;
      GetDenseDataPartitioned<uint8_t>(reinterpret_cast<const uint8_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint8_t>(&cuda_data_uint8_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint8_t, uint16_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint8_t, uint32_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint8_t, uint64_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
108
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
      }
    }
  } else if (bit_type_ == 16) {
    if (!is_sparse_) {
      std::vector<uint16_t> partitioned_data;
      GetDenseDataPartitioned<uint16_t>(reinterpret_cast<const uint16_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint16_t>(&cuda_data_uint16_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint16_t, uint16_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint16_t, uint32_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint16_t, uint64_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
139
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
      }
    }
  } else if (bit_type_ == 32) {
    if (!is_sparse_) {
      std::vector<uint32_t> partitioned_data;
      GetDenseDataPartitioned<uint32_t>(reinterpret_cast<const uint32_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_data_uint32_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint32_t, uint16_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint32_t, uint32_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint32_t, uint64_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
170
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
171
172
173
      }
    }
  } else {
174
    Log::Fatal("Unknown bit type = %d", bit_type_);
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDARowData::DivideCUDAFeatureGroups(const Dataset* train_data, TrainingShareStates* share_state) {
  const uint32_t max_num_bin_per_partition = shared_hist_size_ / 2;
  const std::vector<uint32_t>& column_hist_offsets = share_state->column_hist_offsets();
  std::vector<int> feature_group_num_feature_offsets;
  int offsets = 0;
  int prev_group_index = -1;
  for (int feature_index = 0; feature_index < num_feature_; ++feature_index) {
    const int feature_group_index = train_data->Feature2Group(feature_index);
    if (prev_group_index == -1 || feature_group_index != prev_group_index) {
      feature_group_num_feature_offsets.emplace_back(offsets);
      prev_group_index = feature_group_index;
    }
    ++offsets;
  }
  CHECK_EQ(offsets, num_feature_);
  feature_group_num_feature_offsets.emplace_back(offsets);

  uint32_t start_hist_offset = 0;
  feature_partition_column_index_offsets_.clear();
  column_hist_offsets_.clear();
  partition_hist_offsets_.clear();
  feature_partition_column_index_offsets_.emplace_back(0);
  partition_hist_offsets_.emplace_back(0);
  const int num_feature_groups = train_data->num_feature_groups();
  int column_index = 0;
  num_feature_partitions_ = 0;
  large_bin_partitions_.clear();
  small_bin_partitions_.clear();
  for (int feature_group_index = 0; feature_group_index < num_feature_groups; ++feature_group_index) {
    if (!train_data->IsMultiGroup(feature_group_index)) {
      const uint32_t column_feature_hist_start = column_hist_offsets[column_index];
      const uint32_t column_feature_hist_end = column_hist_offsets[column_index + 1];
      const uint32_t num_bin_in_dense_group = column_feature_hist_end - column_feature_hist_start;

      // if one column has too many bins, use a separate partition for that column
      if (num_bin_in_dense_group > max_num_bin_per_partition) {
        feature_partition_column_index_offsets_.emplace_back(column_index + 1);
        start_hist_offset = column_feature_hist_end;
        partition_hist_offsets_.emplace_back(start_hist_offset);
        large_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
        column_hist_offsets_.emplace_back(0);
        ++column_index;
        continue;
      }

      // try if adding this column exceed the maximum number per partition
      const uint32_t cur_hist_num_bin = column_feature_hist_end - start_hist_offset;
      if (cur_hist_num_bin > max_num_bin_per_partition) {
        feature_partition_column_index_offsets_.emplace_back(column_index);
        start_hist_offset = column_feature_hist_start;
        partition_hist_offsets_.emplace_back(start_hist_offset);
        small_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
      }
      column_hist_offsets_.emplace_back(column_hist_offsets[column_index] - start_hist_offset);
      if (feature_group_index == num_feature_groups - 1) {
        feature_partition_column_index_offsets_.emplace_back(column_index + 1);
        partition_hist_offsets_.emplace_back(column_hist_offsets.back());
        small_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
      }
      ++column_index;
    } else {
      const int group_feature_index_start = feature_group_num_feature_offsets[feature_group_index];
      const int num_feature_in_group = feature_group_num_feature_offsets[feature_group_index + 1] - group_feature_index_start;
      for (int sub_feature_index = 0; sub_feature_index < num_feature_in_group; ++sub_feature_index) {
        const int feature_index = group_feature_index_start + sub_feature_index;
        const uint32_t column_feature_hist_start = column_hist_offsets[column_index];
        const uint32_t column_feature_hist_end = column_hist_offsets[column_index + 1];
        const uint32_t num_bin_in_dense_group = column_feature_hist_end - column_feature_hist_start;

        // if one column has too many bins, use a separate partition for that column
        if (num_bin_in_dense_group > max_num_bin_per_partition) {
          feature_partition_column_index_offsets_.emplace_back(column_index + 1);
          start_hist_offset = column_feature_hist_end;
          partition_hist_offsets_.emplace_back(start_hist_offset);
          large_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
          column_hist_offsets_.emplace_back(0);
          ++column_index;
          continue;
        }

        // try if adding this column exceed the maximum number per partition
        const uint32_t cur_hist_num_bin = column_feature_hist_end - start_hist_offset;
        if (cur_hist_num_bin > max_num_bin_per_partition) {
          feature_partition_column_index_offsets_.emplace_back(column_index);
          start_hist_offset = column_feature_hist_start;
          partition_hist_offsets_.emplace_back(start_hist_offset);
          small_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
        }
        column_hist_offsets_.emplace_back(column_hist_offsets[column_index] - start_hist_offset);
        if (feature_group_index == num_feature_groups - 1 && sub_feature_index == num_feature_in_group - 1) {
          CHECK_EQ(feature_index, num_feature_ - 1);
          feature_partition_column_index_offsets_.emplace_back(column_index + 1);
          partition_hist_offsets_.emplace_back(column_hist_offsets.back());
          small_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
        }
        ++column_index;
      }
    }
  }
  column_hist_offsets_.emplace_back(column_hist_offsets.back() - start_hist_offset);
  max_num_column_per_partition_ = 0;
  for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) {
    const int num_column = feature_partition_column_index_offsets_[i + 1] - feature_partition_column_index_offsets_[i];
    if (num_column > max_num_column_per_partition_) {
      max_num_column_per_partition_ = num_column;
    }
  }

  InitCUDAMemoryFromHostMemory<int>(&cuda_feature_partition_column_index_offsets_,
    feature_partition_column_index_offsets_.data(),
    feature_partition_column_index_offsets_.size(),
    __FILE__,
    __LINE__);

  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_column_hist_offsets_,
    column_hist_offsets_.data(),
    column_hist_offsets_.size(),
    __FILE__,
    __LINE__);

  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_partition_hist_offsets_,
    partition_hist_offsets_.data(),
    partition_hist_offsets_.size(),
    __FILE__,
    __LINE__);
}

template <typename BIN_TYPE>
void CUDARowData::GetDenseDataPartitioned(const BIN_TYPE* row_wise_data, std::vector<BIN_TYPE>* partitioned_data) {
  const int num_total_columns = feature_partition_column_index_offsets_.back();
  partitioned_data->resize(static_cast<size_t>(num_total_columns) * static_cast<size_t>(num_data_), 0);
  BIN_TYPE* out_data = partitioned_data->data();
  Threading::For<data_size_t>(0, num_data_, 512,
    [this, num_total_columns, row_wise_data, out_data] (int /*thread_index*/, data_size_t start, data_size_t end) {
      for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) {
        const int num_prev_columns = static_cast<int>(feature_partition_column_index_offsets_[i]);
321
        const size_t offset = static_cast<size_t>(num_data_) * static_cast<size_t>(num_prev_columns);
322
323
324
325
        const int partition_column_start = feature_partition_column_index_offsets_[i];
        const int partition_column_end = feature_partition_column_index_offsets_[i + 1];
        const int num_columns_in_cur_partition = partition_column_end - partition_column_start;
        for (data_size_t data_index = start; data_index < end; ++data_index) {
326
          const size_t data_offset = offset + static_cast<size_t>(data_index) * num_columns_in_cur_partition;
327
          const size_t read_data_offset = static_cast<size_t>(data_index) * num_total_columns;
328
          for (int column_index = 0; column_index < num_columns_in_cur_partition; ++column_index) {
329
            const size_t true_column_index = read_data_offset + column_index + partition_column_start;
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
            const BIN_TYPE bin = row_wise_data[true_column_index];
            out_data[data_offset + column_index] = bin;
          }
        }
      }
    });
}

template <typename BIN_TYPE, typename DATA_PTR_TYPE>
void CUDARowData::GetSparseDataPartitioned(
  const BIN_TYPE* row_wise_data,
  const DATA_PTR_TYPE* row_ptr,
  std::vector<std::vector<BIN_TYPE>>* partitioned_data,
  std::vector<std::vector<DATA_PTR_TYPE>>* partitioned_row_ptr,
  std::vector<DATA_PTR_TYPE>* partition_ptr) {
  const int num_partitions = static_cast<int>(feature_partition_column_index_offsets_.size()) - 1;
  partitioned_data->resize(num_partitions);
  partitioned_row_ptr->resize(num_partitions);
  std::vector<int> thread_max_elements_per_row(num_threads_, 0);
  Threading::For<int>(0, num_partitions, 1,
    [partitioned_data, partitioned_row_ptr, row_ptr, row_wise_data, &thread_max_elements_per_row, this] (int thread_index, int start, int end) {
      for (int partition_index = start; partition_index < end; ++partition_index) {
        std::vector<BIN_TYPE>& data_for_this_partition = partitioned_data->at(partition_index);
        std::vector<DATA_PTR_TYPE>& row_ptr_for_this_partition = partitioned_row_ptr->at(partition_index);
        const int partition_hist_start = partition_hist_offsets_[partition_index];
        const int partition_hist_end = partition_hist_offsets_[partition_index + 1];
        DATA_PTR_TYPE offset = 0;
        row_ptr_for_this_partition.clear();
        data_for_this_partition.clear();
        row_ptr_for_this_partition.emplace_back(offset);
        for (data_size_t data_index = 0; data_index < num_data_; ++data_index) {
          const DATA_PTR_TYPE row_start = row_ptr[data_index];
          const DATA_PTR_TYPE row_end = row_ptr[data_index + 1];
          const BIN_TYPE* row_data_start = row_wise_data + row_start;
          const BIN_TYPE* row_data_end = row_wise_data + row_end;
          const size_t partition_start_in_row = std::lower_bound(row_data_start, row_data_end, partition_hist_start) - row_data_start;
          const size_t partition_end_in_row = std::lower_bound(row_data_start, row_data_end, partition_hist_end) - row_data_start;
          for (size_t pos = partition_start_in_row; pos < partition_end_in_row; ++pos) {
            const BIN_TYPE bin = row_data_start[pos];
            CHECK_GE(bin, static_cast<BIN_TYPE>(partition_hist_start));
            data_for_this_partition.emplace_back(bin - partition_hist_start);
          }
          CHECK_GE(partition_end_in_row, partition_start_in_row);
          const data_size_t num_elements_in_row = partition_end_in_row - partition_start_in_row;
          offset += static_cast<DATA_PTR_TYPE>(num_elements_in_row);
          row_ptr_for_this_partition.emplace_back(offset);
          if (num_elements_in_row > thread_max_elements_per_row[thread_index]) {
            thread_max_elements_per_row[thread_index] = num_elements_in_row;
          }
        }
      }
    });
  partition_ptr->clear();
  DATA_PTR_TYPE offset = 0;
  partition_ptr->emplace_back(offset);
  for (size_t i = 0; i < partitioned_row_ptr->size(); ++i) {
    offset += partitioned_row_ptr->at(i).back();
    partition_ptr->emplace_back(offset);
  }
  max_num_column_per_partition_ = 0;
  for (int thread_index = 0; thread_index < num_threads_; ++thread_index) {
    if (thread_max_elements_per_row[thread_index] > max_num_column_per_partition_) {
      max_num_column_per_partition_ = thread_max_elements_per_row[thread_index];
    }
  }
}

template <typename BIN_TYPE, typename ROW_PTR_TYPE>
void CUDARowData::InitSparseData(const BIN_TYPE* host_data,
                                 const ROW_PTR_TYPE* host_row_ptr,
                                 BIN_TYPE** cuda_data,
                                 ROW_PTR_TYPE** cuda_row_ptr,
                                 ROW_PTR_TYPE** cuda_partition_ptr) {
  std::vector<std::vector<BIN_TYPE>> partitioned_data;
  std::vector<std::vector<ROW_PTR_TYPE>> partitioned_data_ptr;
  std::vector<ROW_PTR_TYPE> partition_ptr;
  GetSparseDataPartitioned<BIN_TYPE, ROW_PTR_TYPE>(host_data, host_row_ptr, &partitioned_data, &partitioned_data_ptr, &partition_ptr);
  InitCUDAMemoryFromHostMemory<ROW_PTR_TYPE>(cuda_partition_ptr, partition_ptr.data(), partition_ptr.size(), __FILE__, __LINE__);
  AllocateCUDAMemory<BIN_TYPE>(cuda_data, partition_ptr.back(), __FILE__, __LINE__);
  AllocateCUDAMemory<ROW_PTR_TYPE>(cuda_row_ptr, (num_data_ + 1) * partitioned_data_ptr.size(),  __FILE__, __LINE__);
  for (size_t i = 0; i < partitioned_data.size(); ++i) {
    const std::vector<ROW_PTR_TYPE>& data_ptr_for_this_partition = partitioned_data_ptr[i];
    const std::vector<BIN_TYPE>& data_for_this_partition = partitioned_data[i];
    CopyFromHostToCUDADevice<BIN_TYPE>((*cuda_data) + partition_ptr[i], data_for_this_partition.data(), data_for_this_partition.size(), __FILE__, __LINE__);
    CopyFromHostToCUDADevice<ROW_PTR_TYPE>((*cuda_row_ptr) + i * (num_data_ + 1), data_ptr_for_this_partition.data(), data_ptr_for_this_partition.size(), __FILE__, __LINE__);
  }
}

template <typename BIN_TYPE>
const BIN_TYPE* CUDARowData::GetBin() const {
  if (bit_type_ == 8) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint8_t_);
  } else if (bit_type_ == 16) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint16_t_);
  } else if (bit_type_ == 32) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint32_t_);
  } else {
    Log::Fatal("Unknown bit_type %d for GetBin.", bit_type_);
  }
}

template const uint8_t* CUDARowData::GetBin<uint8_t>() const;

template const uint16_t* CUDARowData::GetBin<uint16_t>() const;

template const uint32_t* CUDARowData::GetBin<uint32_t>() const;

template <typename PTR_TYPE>
const PTR_TYPE* CUDARowData::GetRowPtr() const {
  if (row_ptr_bit_type_ == 16) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint16_t_);
  } else if (row_ptr_bit_type_ == 32) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint32_t_);
  } else if (row_ptr_bit_type_ == 64) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint64_t_);
  } else {
    Log::Fatal("Unknown row_ptr_bit_type = %d for GetRowPtr.", row_ptr_bit_type_);
  }
}

template const uint16_t* CUDARowData::GetRowPtr<uint16_t>() const;

template const uint32_t* CUDARowData::GetRowPtr<uint32_t>() const;

template const uint64_t* CUDARowData::GetRowPtr<uint64_t>() const;

template <typename PTR_TYPE>
const PTR_TYPE* CUDARowData::GetPartitionPtr() const {
  if (row_ptr_bit_type_ == 16) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint16_t_);
  } else if (row_ptr_bit_type_ == 32) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint32_t_);
  } else if (row_ptr_bit_type_ == 64) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint64_t_);
  } else {
    Log::Fatal("Unknown row_ptr_bit_type = %d for GetPartitionPtr.", row_ptr_bit_type_);
  }
}

template const uint16_t* CUDARowData::GetPartitionPtr<uint16_t>() const;

template const uint32_t* CUDARowData::GetPartitionPtr<uint32_t>() const;

template const uint64_t* CUDARowData::GetPartitionPtr<uint64_t>() const;

}  // namespace LightGBM

477
#endif  // USE_CUDA