/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_row_data.hpp>

namespace LightGBM {

CUDARowData::CUDARowData(const Dataset* train_data,
                         const TrainingShareStates* train_share_state,
                         const int gpu_device_id,
                         const bool gpu_use_dp):
  gpu_device_id_(gpu_device_id),
  gpu_use_dp_(gpu_use_dp) {
  num_threads_ = OMP_NUM_THREADS();
  num_data_ = train_data->num_data();
  const auto& feature_hist_offsets = train_share_state->feature_hist_offsets();
  if (gpu_use_dp_) {
    shared_hist_size_ = DP_SHARED_HIST_SIZE;
  } else {
    shared_hist_size_ = SP_SHARED_HIST_SIZE;
  }
  if (feature_hist_offsets.empty()) {
    num_total_bin_ = 0;
  } else {
    num_total_bin_ = static_cast<int>(feature_hist_offsets.back());
  }
  num_feature_group_ = train_data->num_feature_groups();
  num_feature_ = train_data->num_features();
  if (gpu_device_id >= 0) {
    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
  } else {
    SetCUDADevice(0, __FILE__, __LINE__);
  }
  cuda_data_uint8_t_ = nullptr;
  cuda_data_uint16_t_ = nullptr;
  cuda_data_uint32_t_ = nullptr;
  cuda_row_ptr_uint16_t_ = nullptr;
  cuda_row_ptr_uint32_t_ = nullptr;
  cuda_row_ptr_uint64_t_ = nullptr;
  cuda_partition_ptr_uint16_t_ = nullptr;
  cuda_partition_ptr_uint32_t_ = nullptr;
  cuda_partition_ptr_uint64_t_ = nullptr;
  cuda_feature_partition_column_index_offsets_ = nullptr;
  cuda_column_hist_offsets_ = nullptr;
  cuda_partition_hist_offsets_ = nullptr;
  cuda_block_buffer_uint16_t_ = nullptr;
  cuda_block_buffer_uint32_t_ = nullptr;
  cuda_block_buffer_uint64_t_ = nullptr;
}

CUDARowData::~CUDARowData() {
  DeallocateCUDAMemory<uint8_t>(&cuda_data_uint8_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_data_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_data_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_row_ptr_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_row_ptr_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint64_t>(&cuda_row_ptr_uint64_t_, __FILE__, __LINE__);
  // also release the partition pointer buffers allocated in InitSparseData()
  DeallocateCUDAMemory<uint16_t>(&cuda_partition_ptr_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_partition_ptr_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint64_t>(&cuda_partition_ptr_uint64_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_feature_partition_column_index_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_column_hist_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_partition_hist_offsets_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint16_t>(&cuda_block_buffer_uint16_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_block_buffer_uint32_t_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint64_t>(&cuda_block_buffer_uint64_t_, __FILE__, __LINE__);
}
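
// Init() copies the row-wise training data to the GPU. TrainingShareStates
// stores bin values with the narrowest integer type that can hold the largest
// bin index (bit_type_ of 8, 16, or 32) and, for sparse data, row offsets with
// the narrowest type that can hold the total number of stored elements
// (row_ptr_bit_type_ of 16, 32, or 64); the dispatch below covers each
// combination. As an illustrative example, a dataset with 300 total bins and a
// sparse representation of 70,000 stored elements needs uint16_t bins
// (300 > 255) and uint32_t row offsets (70,000 > 65,535), so Init() would take
// the InitSparseData<uint16_t, uint32_t> branch.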
void CUDARowData::Init(const Dataset* train_data, TrainingShareStates* train_share_state) {
  if (num_feature_ == 0) {
    return;
  }
  DivideCUDAFeatureGroups(train_data, train_share_state);
  bit_type_ = 0;
  size_t total_size = 0;
  const void* host_row_ptr = nullptr;
  row_ptr_bit_type_ = 0;
  const void* host_data = train_share_state->GetRowWiseData(&bit_type_, &total_size, &is_sparse_, &host_row_ptr, &row_ptr_bit_type_);
  if (bit_type_ == 8) {
    if (!is_sparse_) {
      std::vector<uint8_t> partitioned_data;
      GetDenseDataPartitioned<uint8_t>(reinterpret_cast<const uint8_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint8_t>(&cuda_data_uint8_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint8_t, uint16_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint8_t, uint32_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint8_t, uint64_t>(
          reinterpret_cast<const uint8_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint8_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
      }
    }
  } else if (bit_type_ == 16) {
    if (!is_sparse_) {
      std::vector<uint16_t> partitioned_data;
      GetDenseDataPartitioned<uint16_t>(reinterpret_cast<const uint16_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint16_t>(&cuda_data_uint16_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint16_t, uint16_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint16_t, uint32_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint16_t, uint64_t>(
          reinterpret_cast<const uint16_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint16_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
      }
    }
  } else if (bit_type_ == 32) {
    if (!is_sparse_) {
      std::vector<uint32_t> partitioned_data;
      GetDenseDataPartitioned<uint32_t>(reinterpret_cast<const uint32_t*>(host_data), &partitioned_data);
      InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_data_uint32_t_, partitioned_data.data(), total_size, __FILE__, __LINE__);
    } else {
      if (row_ptr_bit_type_ == 16) {
        InitSparseData<uint32_t, uint16_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint16_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint16_t_,
          &cuda_partition_ptr_uint16_t_);
      } else if (row_ptr_bit_type_ == 32) {
        InitSparseData<uint32_t, uint32_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint32_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint32_t_,
          &cuda_partition_ptr_uint32_t_);
      } else if (row_ptr_bit_type_ == 64) {
        InitSparseData<uint32_t, uint64_t>(
          reinterpret_cast<const uint32_t*>(host_data),
          reinterpret_cast<const uint64_t*>(host_row_ptr),
          &cuda_data_uint32_t_,
          &cuda_row_ptr_uint64_t_,
          &cuda_partition_ptr_uint64_t_);
      } else {
        Log::Fatal("Unknown data ptr bit type %d", row_ptr_bit_type_);
      }
    }
  } else {
    Log::Fatal("Unknown bit type = %d", bit_type_);
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
}
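
// DivideCUDAFeatureGroups() packs the feature columns into partitions whose
// histograms fit into the shared-memory histogram buffer: shared_hist_size_
// counts histogram entries, and each bin takes two entries (gradient and
// hessian), hence the budget of shared_hist_size_ / 2 bins per partition.
// A column whose bin count alone exceeds the budget closes the current
// partition as a "large bin" partition; otherwise columns are packed greedily,
// in column order, into "small bin" partitions.
//
// Illustrative walkthrough, assuming a budget of 3072 bins and four
// single-feature groups with 5000, 2000, 2000, and 1000 bins: column 0 exceeds
// the budget by itself and becomes large-bin partition 0; column 1 opens
// small-bin partition 1; adding column 2 would grow that partition to 4000
// bins, so column 2 starts small-bin partition 2; column 3 (cumulative 3000
// bins) joins partition 2.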
void CUDARowData::DivideCUDAFeatureGroups(const Dataset* train_data, TrainingShareStates* share_state) {
  const uint32_t max_num_bin_per_partition = shared_hist_size_ / 2;
  const std::vector<uint32_t>& column_hist_offsets = share_state->column_hist_offsets();
  std::vector<int> feature_group_num_feature_offsets;
  int offsets = 0;
  int prev_group_index = -1;
  for (int feature_index = 0; feature_index < num_feature_; ++feature_index) {
    const int feature_group_index = train_data->Feature2Group(feature_index);
    if (prev_group_index == -1 || feature_group_index != prev_group_index) {
      feature_group_num_feature_offsets.emplace_back(offsets);
      prev_group_index = feature_group_index;
    }
    ++offsets;
  }
  CHECK_EQ(offsets, num_feature_);
  feature_group_num_feature_offsets.emplace_back(offsets);

  uint32_t start_hist_offset = 0;
  feature_partition_column_index_offsets_.clear();
  column_hist_offsets_.clear();
  partition_hist_offsets_.clear();
  feature_partition_column_index_offsets_.emplace_back(0);
  partition_hist_offsets_.emplace_back(0);
  const int num_feature_groups = train_data->num_feature_groups();
  int column_index = 0;
  num_feature_partitions_ = 0;
  large_bin_partitions_.clear();
  small_bin_partitions_.clear();
  for (int feature_group_index = 0; feature_group_index < num_feature_groups; ++feature_group_index) {
    if (!train_data->IsMultiGroup(feature_group_index)) {
      const uint32_t column_feature_hist_start = column_hist_offsets[column_index];
      const uint32_t column_feature_hist_end = column_hist_offsets[column_index + 1];
      const uint32_t num_bin_in_dense_group = column_feature_hist_end - column_feature_hist_start;
      // if one column has too many bins, use a separate partition for that column
      if (num_bin_in_dense_group > max_num_bin_per_partition) {
        feature_partition_column_index_offsets_.emplace_back(column_index + 1);
        start_hist_offset = column_feature_hist_end;
        partition_hist_offsets_.emplace_back(start_hist_offset);
        large_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
        column_hist_offsets_.emplace_back(0);
        ++column_index;
        continue;
      }
      // check whether adding this column would exceed the maximum number of bins per partition
      const uint32_t cur_hist_num_bin = column_feature_hist_end - start_hist_offset;
      if (cur_hist_num_bin > max_num_bin_per_partition) {
        feature_partition_column_index_offsets_.emplace_back(column_index);
        start_hist_offset = column_feature_hist_start;
        partition_hist_offsets_.emplace_back(start_hist_offset);
        small_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
      }
      column_hist_offsets_.emplace_back(column_hist_offsets[column_index] - start_hist_offset);
      if (feature_group_index == num_feature_groups - 1) {
        feature_partition_column_index_offsets_.emplace_back(column_index + 1);
        partition_hist_offsets_.emplace_back(column_hist_offsets.back());
        small_bin_partitions_.emplace_back(num_feature_partitions_);
        ++num_feature_partitions_;
      }
      ++column_index;
    } else {
      const int group_feature_index_start = feature_group_num_feature_offsets[feature_group_index];
      const int num_feature_in_group = feature_group_num_feature_offsets[feature_group_index + 1] - group_feature_index_start;
      for (int sub_feature_index = 0; sub_feature_index < num_feature_in_group; ++sub_feature_index) {
        const int feature_index = group_feature_index_start + sub_feature_index;
        const uint32_t column_feature_hist_start = column_hist_offsets[column_index];
        const uint32_t column_feature_hist_end = column_hist_offsets[column_index + 1];
        const uint32_t num_bin_in_dense_group = column_feature_hist_end - column_feature_hist_start;
        // if one column has too many bins, use a separate partition for that column
        if (num_bin_in_dense_group > max_num_bin_per_partition) {
          feature_partition_column_index_offsets_.emplace_back(column_index + 1);
          start_hist_offset = column_feature_hist_end;
          partition_hist_offsets_.emplace_back(start_hist_offset);
          large_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
          column_hist_offsets_.emplace_back(0);
          ++column_index;
          continue;
        }
        // check whether adding this column would exceed the maximum number of bins per partition
        const uint32_t cur_hist_num_bin = column_feature_hist_end - start_hist_offset;
        if (cur_hist_num_bin > max_num_bin_per_partition) {
          feature_partition_column_index_offsets_.emplace_back(column_index);
          start_hist_offset = column_feature_hist_start;
          partition_hist_offsets_.emplace_back(start_hist_offset);
          small_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
        }
        column_hist_offsets_.emplace_back(column_hist_offsets[column_index] - start_hist_offset);
        if (feature_group_index == num_feature_groups - 1 && sub_feature_index == num_feature_in_group - 1) {
          CHECK_EQ(feature_index, num_feature_ - 1);
          feature_partition_column_index_offsets_.emplace_back(column_index + 1);
          partition_hist_offsets_.emplace_back(column_hist_offsets.back());
          small_bin_partitions_.emplace_back(num_feature_partitions_);
          ++num_feature_partitions_;
        }
        ++column_index;
      }
    }
  }
  column_hist_offsets_.emplace_back(column_hist_offsets.back() - start_hist_offset);
  max_num_column_per_partition_ = 0;
  for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) {
    const int num_column = feature_partition_column_index_offsets_[i + 1] - feature_partition_column_index_offsets_[i];
    if (num_column > max_num_column_per_partition_) {
      max_num_column_per_partition_ = num_column;
    }
  }
  InitCUDAMemoryFromHostMemory<int>(&cuda_feature_partition_column_index_offsets_,
                                    feature_partition_column_index_offsets_.data(),
                                    feature_partition_column_index_offsets_.size(),
                                    __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_column_hist_offsets_,
                                         column_hist_offsets_.data(),
                                         column_hist_offsets_.size(),
                                         __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_partition_hist_offsets_,
                                         partition_hist_offsets_.data(),
                                         partition_hist_offsets_.size(),
                                         __FILE__, __LINE__);
}
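
// GetDenseDataPartitioned() reorders the host's row-major bin matrix
// (num_data_ rows by num_total_columns columns) into a partition-blocked
// layout: all rows of partition 0's columns first, then all rows of
// partition 1's columns, and so on, so that each partition's histogram kernel
// reads from one contiguous block. Illustrative layout for 2 rows and columns
// {c0, c1 | c2} split into two partitions:
//   input  (row-major):          r0c0 r0c1 r0c2   r1c0 r1c1 r1c2
//   output (partition-blocked):  r0c0 r0c1 r1c0 r1c1 | r0c2 r1c2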
template <typename BIN_TYPE>
void CUDARowData::GetDenseDataPartitioned(const BIN_TYPE* row_wise_data, std::vector<BIN_TYPE>* partitioned_data) {
  const int num_total_columns = feature_partition_column_index_offsets_.back();
  partitioned_data->resize(static_cast<size_t>(num_total_columns) * static_cast<size_t>(num_data_), 0);
  BIN_TYPE* out_data = partitioned_data->data();
  Threading::For<data_size_t>(0, num_data_, 512,
    [this, num_total_columns, row_wise_data, out_data] (int /*thread_index*/, data_size_t start, data_size_t end) {
      for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) {
        const int num_prev_columns = static_cast<int>(feature_partition_column_index_offsets_[i]);
        const size_t offset = static_cast<size_t>(num_data_) * static_cast<size_t>(num_prev_columns);
        const int partition_column_start = feature_partition_column_index_offsets_[i];
        const int partition_column_end = feature_partition_column_index_offsets_[i + 1];
        const int num_columns_in_cur_partition = partition_column_end - partition_column_start;
        for (data_size_t data_index = start; data_index < end; ++data_index) {
          const size_t data_offset = offset + static_cast<size_t>(data_index) * num_columns_in_cur_partition;
          const size_t read_data_offset = static_cast<size_t>(data_index) * num_total_columns;
          for (int column_index = 0; column_index < num_columns_in_cur_partition; ++column_index) {
            const size_t true_column_index = read_data_offset + column_index + partition_column_start;
            const BIN_TYPE bin = row_wise_data[true_column_index];
            out_data[data_offset + column_index] = bin;
          }
        }
      }
    });
}
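
// GetSparseDataPartitioned() slices a CSR-style matrix (row_ptr offsets over
// row_wise_data) into one CSR matrix per partition. The bin values within a
// row are stored in ascending order, so std::lower_bound can find the
// sub-range of each row that falls inside [partition_hist_start,
// partition_hist_end); the bins in that sub-range are re-based to the
// partition's origin. Illustrative example: for a row holding bins
// {3, 120, 431} and a partition covering [100, 400), the extracted slice is
// {120}, stored as {20} after subtracting partition_hist_start = 100.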
template <typename BIN_TYPE, typename DATA_PTR_TYPE>
void CUDARowData::GetSparseDataPartitioned(
  const BIN_TYPE* row_wise_data,
  const DATA_PTR_TYPE* row_ptr,
  std::vector<std::vector<BIN_TYPE>>* partitioned_data,
  std::vector<std::vector<DATA_PTR_TYPE>>* partitioned_row_ptr,
  std::vector<DATA_PTR_TYPE>* partition_ptr) {
  const int num_partitions = static_cast<int>(feature_partition_column_index_offsets_.size()) - 1;
  partitioned_data->resize(num_partitions);
  partitioned_row_ptr->resize(num_partitions);
  std::vector<data_size_t> thread_max_elements_per_row(num_threads_, 0);
  Threading::For<int>(0, num_partitions, 1,
    [partitioned_data, partitioned_row_ptr, row_ptr, row_wise_data, &thread_max_elements_per_row, this] (int thread_index, int start, int end) {
      for (int partition_index = start; partition_index < end; ++partition_index) {
        std::vector<BIN_TYPE>& data_for_this_partition = partitioned_data->at(partition_index);
        std::vector<DATA_PTR_TYPE>& row_ptr_for_this_partition = partitioned_row_ptr->at(partition_index);
        const int partition_hist_start = partition_hist_offsets_[partition_index];
        const int partition_hist_end = partition_hist_offsets_[partition_index + 1];
        DATA_PTR_TYPE offset = 0;
        row_ptr_for_this_partition.clear();
        data_for_this_partition.clear();
        row_ptr_for_this_partition.emplace_back(offset);
        for (data_size_t data_index = 0; data_index < num_data_; ++data_index) {
          const DATA_PTR_TYPE row_start = row_ptr[data_index];
          const DATA_PTR_TYPE row_end = row_ptr[data_index + 1];
          const BIN_TYPE* row_data_start = row_wise_data + row_start;
          const BIN_TYPE* row_data_end = row_wise_data + row_end;
          const size_t partition_start_in_row = std::lower_bound(row_data_start, row_data_end, partition_hist_start) - row_data_start;
          const size_t partition_end_in_row = std::lower_bound(row_data_start, row_data_end, partition_hist_end) - row_data_start;
          for (size_t pos = partition_start_in_row; pos < partition_end_in_row; ++pos) {
            const BIN_TYPE bin = row_data_start[pos];
            CHECK_GE(bin, static_cast<BIN_TYPE>(partition_hist_start));
            data_for_this_partition.emplace_back(bin - partition_hist_start);
          }
          CHECK_GE(partition_end_in_row, partition_start_in_row);
          const data_size_t num_elements_in_row = partition_end_in_row - partition_start_in_row;
          offset += static_cast<DATA_PTR_TYPE>(num_elements_in_row);
          row_ptr_for_this_partition.emplace_back(offset);
          if (num_elements_in_row > thread_max_elements_per_row[thread_index]) {
            thread_max_elements_per_row[thread_index] = num_elements_in_row;
          }
        }
      }
    });
  partition_ptr->clear();
  DATA_PTR_TYPE offset = 0;
  partition_ptr->emplace_back(offset);
  for (size_t i = 0; i < partitioned_row_ptr->size(); ++i) {
    offset += partitioned_row_ptr->at(i).back();
    partition_ptr->emplace_back(offset);
  }
  max_num_column_per_partition_ = 0;
  for (int thread_index = 0; thread_index < num_threads_; ++thread_index) {
    if (thread_max_elements_per_row[thread_index] > max_num_column_per_partition_) {
      max_num_column_per_partition_ = thread_max_elements_per_row[thread_index];
    }
  }
}

template <typename BIN_TYPE, typename ROW_PTR_TYPE>
void CUDARowData::InitSparseData(const BIN_TYPE* host_data,
                                 const ROW_PTR_TYPE* host_row_ptr,
                                 BIN_TYPE** cuda_data,
                                 ROW_PTR_TYPE** cuda_row_ptr,
                                 ROW_PTR_TYPE** cuda_partition_ptr) {
  std::vector<std::vector<BIN_TYPE>> partitioned_data;
  std::vector<std::vector<ROW_PTR_TYPE>> partitioned_data_ptr;
  std::vector<ROW_PTR_TYPE> partition_ptr;
  GetSparseDataPartitioned<BIN_TYPE, ROW_PTR_TYPE>(host_data, host_row_ptr, &partitioned_data, &partitioned_data_ptr, &partition_ptr);
  InitCUDAMemoryFromHostMemory<ROW_PTR_TYPE>(cuda_partition_ptr, partition_ptr.data(), partition_ptr.size(), __FILE__, __LINE__);
  AllocateCUDAMemory<BIN_TYPE>(cuda_data, partition_ptr.back(), __FILE__, __LINE__);
  AllocateCUDAMemory<ROW_PTR_TYPE>(cuda_row_ptr, (num_data_ + 1) * partitioned_data_ptr.size(), __FILE__, __LINE__);
  for (size_t i = 0; i < partitioned_data.size(); ++i) {
    const std::vector<ROW_PTR_TYPE>& data_ptr_for_this_partition = partitioned_data_ptr[i];
    const std::vector<BIN_TYPE>& data_for_this_partition = partitioned_data[i];
    CopyFromHostToCUDADevice<BIN_TYPE>((*cuda_data) + partition_ptr[i], data_for_this_partition.data(), data_for_this_partition.size(), __FILE__, __LINE__);
    CopyFromHostToCUDADevice<ROW_PTR_TYPE>((*cuda_row_ptr) + i * (num_data_ + 1), data_ptr_for_this_partition.data(), data_ptr_for_this_partition.size(), __FILE__, __LINE__);
  }
}
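
// The typed getters below expose the device buffers through whichever element
// type the caller expects; the reinterpret_casts are no-ops when the requested
// type matches the stored bit width, and requesting a mismatched type is a
// caller error. Illustrative usage, where row_data stands for any CUDARowData
// instance whose row_ptr_bit_type_ is known to be 32:
//
//   const uint32_t* cuda_row_ptr = row_data.GetRowPtr<uint32_t>();
//   const uint32_t* cuda_partition_ptr = row_data.GetPartitionPtr<uint32_t>();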
template <typename BIN_TYPE>
const BIN_TYPE* CUDARowData::GetBin() const {
  if (bit_type_ == 8) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint8_t_);
  } else if (bit_type_ == 16) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint16_t_);
  } else if (bit_type_ == 32) {
    return reinterpret_cast<const BIN_TYPE*>(cuda_data_uint32_t_);
  } else {
    Log::Fatal("Unknown bit_type %d for GetBin.", bit_type_);
  }
  return nullptr;  // unreachable: Log::Fatal throws
}

template const uint8_t* CUDARowData::GetBin<uint8_t>() const;
template const uint16_t* CUDARowData::GetBin<uint16_t>() const;
template const uint32_t* CUDARowData::GetBin<uint32_t>() const;

template <typename PTR_TYPE>
const PTR_TYPE* CUDARowData::GetRowPtr() const {
  if (row_ptr_bit_type_ == 16) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint16_t_);
  } else if (row_ptr_bit_type_ == 32) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint32_t_);
  } else if (row_ptr_bit_type_ == 64) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_row_ptr_uint64_t_);
  } else {
    Log::Fatal("Unknown row_ptr_bit_type = %d for GetRowPtr.", row_ptr_bit_type_);
  }
  return nullptr;  // unreachable: Log::Fatal throws
}

template const uint16_t* CUDARowData::GetRowPtr<uint16_t>() const;
template const uint32_t* CUDARowData::GetRowPtr<uint32_t>() const;
template const uint64_t* CUDARowData::GetRowPtr<uint64_t>() const;

template <typename PTR_TYPE>
const PTR_TYPE* CUDARowData::GetPartitionPtr() const {
  if (row_ptr_bit_type_ == 16) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint16_t_);
  } else if (row_ptr_bit_type_ == 32) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint32_t_);
  } else if (row_ptr_bit_type_ == 64) {
    return reinterpret_cast<const PTR_TYPE*>(cuda_partition_ptr_uint64_t_);
  } else {
    Log::Fatal("Unknown row_ptr_bit_type = %d for GetPartitionPtr.", row_ptr_bit_type_);
  }
  return nullptr;  // unreachable: Log::Fatal throws
}

template const uint16_t* CUDARowData::GetPartitionPtr<uint16_t>() const;
template const uint32_t* CUDARowData::GetPartitionPtr<uint32_t>() const;
template const uint64_t* CUDARowData::GetPartitionPtr<uint64_t>() const;

}  // namespace LightGBM

#endif  // USE_CUDA