/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifdef USE_CUDA_EXP

// NOTE(review): the target of this first #include was missing from the
// reviewed copy. It is assumed to be the header that declares
// ShuffleReduceSum / ShuffleReduceSumGlobal -- confirm the path.
#include <LightGBM/cuda/cuda_algorithms.hpp>

#include "cuda_binary_metric.hpp"
#include "cuda_pointwise_metric.hpp"
#include "cuda_regression_metric.hpp"

namespace LightGBM {

/*!
 * \brief Evaluate a pointwise metric with one thread per data point and write
 *        per-block partial sums.
 *
 * Launch layout: 1-D grid, NUM_DATA_PER_EVAL_THREAD threads per block (that
 * value is passed to ShuffleReduceSum as the reduction length).
 *
 * \tparam CUDA_METRIC Type providing a static MetricOnPointCUDA(label, score)
 * \tparam USE_WEIGHTS Whether each point's metric is multiplied by its weight
 *                     and the weights themselves are summed as well
 * \param num_data Number of data points; threads whose index is not below it
 *                 contribute 0 to every sum
 * \param labels Per-point labels
 * \param weights Per-point weights; only read when USE_WEIGHTS is true
 * \param scores Per-point scores
 * \param reduce_block_buffer Output: entry [blockIdx.x] receives the block's
 *        metric sum; when USE_WEIGHTS is true, entry [blockIdx.x + gridDim.x]
 *        receives the block's weight sum, so the buffer needs 2 * gridDim.x
 *        entries in that case
 */
template <typename CUDA_METRIC, bool USE_WEIGHTS>
__global__ void EvalKernel(const data_size_t num_data,
                           const label_t* labels,
                           const label_t* weights,
                           const double* scores,
                           double* reduce_block_buffer) {
  // Scratch space handed to ShuffleReduceSum (helper defined outside this file).
  __shared__ double shared_mem_buffer[32];
  const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  const bool in_range = index < num_data;

  double point_metric = 0.0;
  if (in_range) {
    point_metric = USE_WEIGHTS ?
      CUDA_METRIC::MetricOnPointCUDA(labels[index], scores[index]) * weights[index] :
      CUDA_METRIC::MetricOnPointCUDA(labels[index], scores[index]);
  }
  const double block_sum_point_metric =
    ShuffleReduceSum(point_metric, shared_mem_buffer, NUM_DATA_PER_EVAL_THREAD);
  if (threadIdx.x == 0) {
    reduce_block_buffer[blockIdx.x] = block_sum_point_metric;
  }

  if (USE_WEIGHTS) {
    // Only the load is guarded by the bounds check. The block-wide reduction
    // below must be reached by every thread of the block: in the last,
    // partially filled block, calling it from only the in-range threads is
    // divergent participation in a block-level collective.
    double weight = 0.0;
    if (in_range) {
      weight = static_cast<double>(weights[index]);
    }
    // The same shared scratch buffer is reused for the second reduction.
    // NOTE(review): the helper's internals are not visible here; this barrier
    // presumably prevents the second call from overwriting slots that the
    // first call is still reading -- confirm against ShuffleReduceSum.
    // USE_WEIGHTS is a compile-time constant, so the barrier is uniform.
    __syncthreads();
    const double block_sum_weight =
      ShuffleReduceSum(weight, shared_mem_buffer, NUM_DATA_PER_EVAL_THREAD);
    if (threadIdx.x == 0) {
      reduce_block_buffer[blockIdx.x + gridDim.x] = block_sum_weight;
    }
  }
}

/*!
 * \brief Launch EvalKernel over all data points, reduce the per-block partial
 *        sums on the device and copy the two totals back to the host.
 *
 * NOTE(review): the class template parameter list was not visible in the
 * reviewed copy; <HOST_METRIC, CUDA_METRIC> is assumed -- confirm against the
 * class declaration in cuda_pointwise_metric.hpp.
 *
 * \param score Per-point scores, passed straight to the kernel
 * \param sum_loss Host pointer receiving the summed (optionally weighted) metric
 * \param sum_weight Host pointer receiving num_data_ when there are no weights,
 *                   otherwise the sum of all weights
 */
template <typename HOST_METRIC, typename CUDA_METRIC>
void CUDAPointwiseMetricInterface<HOST_METRIC, CUDA_METRIC>::LaunchEvalKernel(
    const double* score, double* sum_loss, double* sum_weight) const {
  if (this->num_data_ <= 0) {
    // Nothing to evaluate. A launch with zero blocks is an invalid
    // configuration, so report empty sums instead of launching.
    *sum_loss = 0.0;
    *sum_weight = 0.0;
    return;
  }

  // One thread per data point, NUM_DATA_PER_EVAL_THREAD threads per block.
  const int num_blocks =
    (this->num_data_ + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
  const bool has_weights = this->cuda_weights_ != nullptr;

  // NOTE(review): kernel template arguments and launch configuration were not
  // visible in the reviewed copy. They are restored from the weight check and
  // from the num_blocks computation above -- confirm.
  if (has_weights) {
    EvalKernel<CUDA_METRIC, true><<<num_blocks, NUM_DATA_PER_EVAL_THREAD>>>(
      this->num_data_, this->cuda_labels_, this->cuda_weights_, score, reduce_block_buffer_.RawData());
  } else {
    EvalKernel<CUDA_METRIC, false><<<num_blocks, NUM_DATA_PER_EVAL_THREAD>>>(
      this->num_data_, this->cuda_labels_, this->cuda_weights_, score, reduce_block_buffer_.RawData());
  }

  // Entries [0, num_blocks) hold the per-block metric sums.
  ShuffleReduceSumGlobal(reduce_block_buffer_.RawData(), num_blocks, reduce_block_buffer_inner_.RawData());
  CopyFromCUDADeviceToHost(sum_loss, reduce_block_buffer_inner_.RawData(), 1, __FILE__, __LINE__);

  // Without weights the total weight is the number of data points.
  *sum_weight = static_cast<double>(this->num_data_);
  if (has_weights) {
    // Entries [num_blocks, 2 * num_blocks) hold the per-block weight sums
    // (the kernel stores them at blockIdx.x + gridDim.x).
    ShuffleReduceSumGlobal(reduce_block_buffer_.RawData() + num_blocks, num_blocks, reduce_block_buffer_inner_.RawData());
    CopyFromCUDADeviceToHost(sum_weight, reduce_block_buffer_inner_.RawData(), 1, __FILE__, __LINE__);
  }
}

// Explicit instantiations.
// NOTE(review): the template arguments of these three instantiations were not
// visible in the reviewed copy. The metric pairs below are an assumption based
// on the included regression / binary metric headers -- confirm that they match
// the metric classes those headers declare.
template void CUDAPointwiseMetricInterface<RMSEMetric, CUDARMSEMetric>::LaunchEvalKernel(
    const double* score, double* sum_loss, double* sum_weight) const;

template void CUDAPointwiseMetricInterface<L2Metric, CUDAL2Metric>::LaunchEvalKernel(
    const double* score, double* sum_loss, double* sum_weight) const;

template void CUDAPointwiseMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>::LaunchEvalKernel(
    const double* score, double* sum_loss, double* sum_weight) const;

}  // namespace LightGBM

#endif  // USE_CUDA_EXP