/*! * Copyright (c) 2021 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ #ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_ #define LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_ #ifdef USE_CUDA_EXP #include #include #include #include #include "cuda_binary_objective.hpp" #include "../multiclass_objective.hpp" #define GET_GRADIENTS_BLOCK_SIZE_MULTICLASS (1024) namespace LightGBM { class CUDAMulticlassSoftmax: public CUDAObjectiveInterface { public: explicit CUDAMulticlassSoftmax(const Config& config); explicit CUDAMulticlassSoftmax(const std::vector& strs); ~CUDAMulticlassSoftmax(); void Init(const Metadata& metadata, data_size_t num_data) override; private: void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const; const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const; // CUDA memory, held by this object CUDAVector cuda_softmax_buffer_; }; class CUDAMulticlassOVA: public CUDAObjectiveInterface { public: explicit CUDAMulticlassOVA(const Config& config); explicit CUDAMulticlassOVA(const std::vector& strs); void Init(const Metadata& metadata, data_size_t num_data) override; void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override; const double* ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override; double BoostFromScore(int class_id) const override { return cuda_binary_loss_[class_id]->BoostFromScore(0); } bool ClassNeedTrain(int class_id) const override { return cuda_binary_loss_[class_id]->ClassNeedTrain(0); } ~CUDAMulticlassOVA(); bool IsCUDAObjective() const override { return true; } private: void LaunchGetGradientsKernel(const double* /*scores*/, score_t* /*gradients*/, score_t* /*hessians*/) const {} std::vector> cuda_binary_loss_; }; } // namespace LightGBM #endif // USE_CUDA_EXP #endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_