/*! * Copyright (c) 2022 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for * license information. */ #ifdef USE_CUDA_EXP #include "cuda_binary_metric.hpp" #include "cuda_pointwise_metric.hpp" #include "cuda_regression_metric.hpp" namespace LightGBM { template void CUDAPointwiseMetricInterface::Init(const Metadata& metadata, data_size_t num_data) { CUDAMetricInterface::Init(metadata, num_data); const int max_num_reduce_blocks = (this->num_data_ + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD; if (this->cuda_weights_ == nullptr) { reduce_block_buffer_.Resize(max_num_reduce_blocks); } else { reduce_block_buffer_.Resize(max_num_reduce_blocks * 2); } const int max_num_reduce_blocks_inner = (max_num_reduce_blocks + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD; if (this->cuda_weights_ == nullptr) { reduce_block_buffer_inner_.Resize(max_num_reduce_blocks_inner); } else { reduce_block_buffer_inner_.Resize(max_num_reduce_blocks_inner * 2); } } template void CUDAPointwiseMetricInterface::Init(const Metadata& metadata, data_size_t num_data); template void CUDAPointwiseMetricInterface::Init(const Metadata& metadata, data_size_t num_data); template void CUDAPointwiseMetricInterface::Init(const Metadata& metadata, data_size_t num_data); } // namespace LightGBM #endif // USE_CUDA_EXP