/*! * Copyright (c) 2022 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for * license information. */ #ifndef LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_ #ifdef USE_CUDA #include #include #include #define NUM_DATA_PER_EVAL_THREAD (1024) namespace LightGBM { template class CUDAPointwiseMetricInterface: public CUDAMetricInterface { public: explicit CUDAPointwiseMetricInterface(const Config& config): CUDAMetricInterface(config), num_class_(config.num_class) {} virtual ~CUDAPointwiseMetricInterface() {} void Init(const Metadata& metadata, data_size_t num_data) override; protected: void LaunchEvalKernel(const double* score_convert, double* sum_loss, double* sum_weight) const; virtual double GetParamFromConfig() const { return 0.0; } mutable CUDAVector score_convert_buffer_; CUDAVector reduce_block_buffer_; CUDAVector reduce_block_buffer_inner_; const int num_class_; }; } // namespace LightGBM #endif // USE_CUDA #endif // LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_