#include #include using namespace LightGBM; #include #include #include #include namespace { PredictionEarlyStopInstance createNone(const PredictionEarlyStopConfig&) { return PredictionEarlyStopInstance{ [](const double*, int) { return false; }, std::numeric_limits::max() // make sure the lambda is almost never called }; } PredictionEarlyStopInstance createMulticlass(const PredictionEarlyStopConfig& config) { // marginThreshold will be captured by value const double marginThreshold = config.marginThreshold; return PredictionEarlyStopInstance{ [marginThreshold](const double* pred, int sz) { if(sz < 2) { Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger"); } // copy and sort std::vector votes(static_cast(sz)); for (int i=0; i < sz; ++i) { votes[i] = pred[i]; } std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater()); const auto margin = votes[0] - votes[1]; if (margin > marginThreshold) { return true; } return false; }, config.roundPeriod }; } PredictionEarlyStopInstance createBinary(const PredictionEarlyStopConfig& config) { // marginThreshold will be captured by value const double marginThreshold = config.marginThreshold; return PredictionEarlyStopInstance{ [marginThreshold](const double* pred, int sz) { if(sz != 1) { Log::Fatal("Binary early stopping needs predictions to be of length one"); } const auto margin = 2.0 * fabs(pred[0]); if (margin > marginThreshold) { return true; } return false; }, config.roundPeriod }; } } namespace LightGBM { PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type, const PredictionEarlyStopConfig& config) { if (type == "none") { return createNone(config); } else if (type == "multiclass") { return createMulticlass(config); } else if (type == "binary") { return createBinary(config); } else { throw std::runtime_error("Unknown early stopping type: " + type); } } }