Commit 548cec82 authored by Jeff Daily's avatar Jeff Daily
Browse files

Merge branch 'master' into rocm3

parents 2f7bd8ef 5dbfcdc4
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2020 Microsoft Corporation. All rights reserved. * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_IO_MULTI_VAL_SPARSE_BIN_HPP_ #ifndef LIGHTGBM_SRC_IO_MULTI_VAL_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_MULTI_VAL_SPARSE_BIN_HPP_ #define LIGHTGBM_SRC_IO_MULTI_VAL_SPARSE_BIN_HPP_
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h> #include <LightGBM/utils/openmp_wrapper.h>
...@@ -445,4 +445,4 @@ MultiValSparseBin<INDEX_T, VAL_T>* MultiValSparseBin<INDEX_T, VAL_T>::Clone() { ...@@ -445,4 +445,4 @@ MultiValSparseBin<INDEX_T, VAL_T>* MultiValSparseBin<INDEX_T, VAL_T>::Clone() {
} // namespace LightGBM } // namespace LightGBM
#endif // LIGHTGBM_IO_MULTI_VAL_SPARSE_BIN_HPP_ #endif // LIGHTGBM_SRC_IO_MULTI_VAL_SPARSE_BIN_HPP_
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
*/ */
#include "parser.hpp" #include "parser.hpp"
#include <functional>
#include <string>
#include <algorithm> #include <algorithm>
#include <functional>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string>
#include <vector>
namespace LightGBM { namespace LightGBM {
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_IO_PARSER_HPP_ #ifndef LIGHTGBM_SRC_IO_PARSER_HPP_
#define LIGHTGBM_IO_PARSER_HPP_ #define LIGHTGBM_SRC_IO_PARSER_HPP_
#include <LightGBM/dataset.h> #include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
...@@ -132,4 +132,4 @@ class LibSVMParser: public Parser { ...@@ -132,4 +132,4 @@ class LibSVMParser: public Parser {
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_IO_PARSER_HPP_ #endif // LIGHTGBM_SRC_IO_PARSER_HPP_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* Licensed under the MIT License. See LICENSE file in the project root for * Licensed under the MIT License. See LICENSE file in the project root for
* license information. * license information.
*/ */
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_ #ifndef LIGHTGBM_SRC_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_ #define LIGHTGBM_SRC_IO_SPARSE_BIN_HPP_
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
...@@ -854,4 +854,4 @@ BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, ...@@ -854,4 +854,4 @@ BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_IO_SPARSE_BIN_HPP_ #endif // LIGHTGBM_SRC_IO_SPARSE_BIN_HPP_
...@@ -6,6 +6,10 @@ ...@@ -6,6 +6,10 @@
#include <LightGBM/train_share_states.h> #include <LightGBM/train_share_states.h>
#include <algorithm>
#include <memory>
#include <vector>
namespace LightGBM { namespace LightGBM {
MultiValBinWrapper::MultiValBinWrapper(MultiValBin* bin, data_size_t num_data, MultiValBinWrapper::MultiValBinWrapper(MultiValBin* bin, data_size_t num_data,
......
...@@ -8,9 +8,13 @@ ...@@ -8,9 +8,13 @@
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h> #include <LightGBM/utils/threading.h>
#include <algorithm>
#include <functional> #include <functional>
#include <iomanip> #include <iomanip>
#include <sstream> #include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace LightGBM { namespace LightGBM {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <LightGBM/application.h> #include <LightGBM/application.h>
#include <iostream> #include <iostream>
#include <string>
#ifdef USE_MPI #ifdef USE_MPI
#include "network/linkers.h" #include "network/linkers.h"
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_BINARY_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_BINARY_METRIC_HPP_
#define LIGHTGBM_METRIC_BINARY_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_BINARY_METRIC_HPP_
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
...@@ -385,4 +385,4 @@ class AveragePrecisionMetric: public Metric { ...@@ -385,4 +385,4 @@ class AveragePrecisionMetric: public Metric {
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_METRIC_BINARY_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_BINARY_METRIC_HPP_
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "cuda_binary_metric.hpp" #include "cuda_binary_metric.hpp"
#include <vector>
namespace LightGBM { namespace LightGBM {
CUDABinaryLoglossMetric::CUDABinaryLoglossMetric(const Config& config): CUDABinaryLoglossMetric::CUDABinaryLoglossMetric(
CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {} const Config& config):CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {}
template <typename HOST_METRIC, typename CUDA_METRIC> template <typename HOST_METRIC, typename CUDA_METRIC>
std::vector<double> CUDABinaryMetricInterface<HOST_METRIC, CUDA_METRIC>::Eval(const double* score, const ObjectiveFunction* objective) const { std::vector<double> CUDABinaryMetricInterface<HOST_METRIC, CUDA_METRIC>::Eval(const double* score, const ObjectiveFunction* objective) const {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* license information. * license information.
*/ */
#ifndef LIGHTGBM_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_
#define LIGHTGBM_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -54,4 +54,4 @@ class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMet ...@@ -54,4 +54,4 @@ class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMet
#endif // USE_CUDA #endif // USE_CUDA
#endif // LIGHTGBM_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_CUDA_CUDA_BINARY_METRIC_HPP_
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* license information. * license information.
*/ */
#ifndef LIGHTGBM_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_
#define LIGHTGBM_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -42,4 +42,4 @@ class CUDAPointwiseMetricInterface: public CUDAMetricInterface<HOST_METRIC> { ...@@ -42,4 +42,4 @@ class CUDAPointwiseMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
#endif // USE_CUDA #endif // USE_CUDA
#endif // LIGHTGBM_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* license information. * license information.
*/ */
#ifndef LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_
#define LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -212,4 +212,4 @@ class CUDATweedieMetric : public CUDARegressionMetricInterface<TweedieMetric, CU ...@@ -212,4 +212,4 @@ class CUDATweedieMetric : public CUDARegressionMetricInterface<TweedieMetric, CU
#endif // USE_CUDA #endif // USE_CUDA
#endif // LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2017 Microsoft Corporation. All rights reserved. * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_MAP_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_MAP_METRIC_HPP_
#define LIGHTGBM_METRIC_MAP_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_MAP_METRIC_HPP_
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
...@@ -165,4 +165,4 @@ class MapMetric:public Metric { ...@@ -165,4 +165,4 @@ class MapMetric:public Metric {
} // namespace LightGBM } // namespace LightGBM
#endif // LIGHTGBM_METRIC_MAP_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_MAP_METRIC_HPP_
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
*/ */
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <string>
#include "binary_metric.hpp" #include "binary_metric.hpp"
#include "map_metric.hpp" #include "map_metric.hpp"
#include "multiclass_metric.hpp" #include "multiclass_metric.hpp"
...@@ -76,6 +78,9 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) { ...@@ -76,6 +78,9 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
return new CUDAGammaDevianceMetric(config); return new CUDAGammaDevianceMetric(config);
} else if (type == std::string("tweedie")) { } else if (type == std::string("tweedie")) {
return new CUDATweedieMetric(config); return new CUDATweedieMetric(config);
} else if (type == std::string("r2")) {
Log::Warning("Metric r2 is not implemented in cuda version. Fall back to evaluation on CPU.");
return new R2Metric(config);
} }
} else { } else {
#endif // USE_CUDA #endif // USE_CUDA
...@@ -125,6 +130,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) { ...@@ -125,6 +130,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
return new GammaDevianceMetric(config); return new GammaDevianceMetric(config);
} else if (type == std::string("tweedie")) { } else if (type == std::string("tweedie")) {
return new TweedieMetric(config); return new TweedieMetric(config);
} else if (type == std::string("r2")) {
return new R2Metric(config);
} }
#ifdef USE_CUDA #ifdef USE_CUDA
} }
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_MULTICLASS_METRIC_HPP_
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
...@@ -237,7 +237,7 @@ class AucMuMetric : public Metric { ...@@ -237,7 +237,7 @@ class AucMuMetric : public Metric {
std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override { std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
// the notation follows that used in the paper introducing the auc-mu metric: // the notation follows that used in the paper introducing the auc-mu metric:
// http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf // https://proceedings.mlr.press/v97/kleiman19a.html
auto S = std::vector<std::vector<double>>(num_class_, std::vector<double>(num_class_, 0)); auto S = std::vector<std::vector<double>>(num_class_, std::vector<double>(num_class_, 0));
int i_start = 0; int i_start = 0;
...@@ -365,4 +365,4 @@ class AucMuMetric : public Metric { ...@@ -365,4 +365,4 @@ class AucMuMetric : public Metric {
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_MULTICLASS_METRIC_HPP_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_RANK_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_RANK_METRIC_HPP_
#define LIGHTGBM_METRIC_RANK_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_RANK_METRIC_HPP_
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
...@@ -166,4 +166,4 @@ class NDCGMetric:public Metric { ...@@ -166,4 +166,4 @@ class NDCGMetric:public Metric {
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_METRIC_RANK_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_RANK_METRIC_HPP_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_REGRESSION_METRIC_HPP_
#define LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_REGRESSION_METRIC_HPP_
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
...@@ -318,5 +318,115 @@ class TweedieMetric : public RegressionMetric<TweedieMetric> { ...@@ -318,5 +318,115 @@ class TweedieMetric : public RegressionMetric<TweedieMetric> {
}; };
class R2Metric: public Metric {
public:
explicit R2Metric(const Config& config) :config_(config) {}
const std::vector<std::string>& GetName() const override {
return name_;
}
double factor_to_bigger_better() const override {
return 1.0f;
}
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back("r2");
num_data_ = num_data;
label_ = metadata.label();
weights_ = metadata.weights();
double sum_label = 0.0f;
if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum_label)
for (data_size_t i = 0; i < num_data_; ++i) {
sum_label += label_[i];
}
} else {
double local_sum_weights = 0.0f;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_sum_weights, sum_label)
for (data_size_t i = 0; i < num_data_; ++i) {
local_sum_weights += weights_[i];
sum_label += label_[i] * weights_[i];
}
sum_weights_ = local_sum_weights;
}
label_mean_ = sum_label / sum_weights_;
total_sum_squares_ = 0.0f;
double local_total_sum_squares = 0.0f;
if (weights_ == nullptr) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_total_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double diff = label_[i] - label_mean_;
local_total_sum_squares += diff * diff;
}
} else {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_total_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double diff = label_[i] - label_mean_;
local_total_sum_squares += diff * diff * weights_[i];
}
}
total_sum_squares_ = local_total_sum_squares;
}
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
double residual_sum_squares = 0.0f;
if (objective == nullptr) {
if (weights_ == nullptr) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double diff = label_[i] - score[i];
residual_sum_squares += diff * diff;
}
} else {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double diff = label_[i] - score[i];
residual_sum_squares += diff * diff * weights_[i];
}
}
} else {
if (weights_ == nullptr) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double t = 0;
objective->ConvertOutput(&score[i], &t);
double diff = label_[i] - t;
residual_sum_squares += diff * diff;
}
} else {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares)
for (data_size_t i = 0; i < num_data_; ++i) {
double t = 0;
objective->ConvertOutput(&score[i], &t);
double diff = label_[i] - t;
residual_sum_squares += diff * diff * weights_[i];
}
}
}
double r2 = 1.0 - (residual_sum_squares / total_sum_squares_);
if (std::fabs(total_sum_squares_) < kZeroThreshold) {
return std::vector<double>(1, std::fabs(residual_sum_squares) < kZeroThreshold ? 1.0 : 0.0);
}
return std::vector<double>(1, r2);
}
protected:
data_size_t num_data_;
const label_t* label_;
const label_t* weights_;
double sum_weights_;
Config config_;
std::vector<std::string> name_;
// Custom members for R2 calculation
double label_mean_;
double total_sum_squares_;
};
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_METRIC_REGRESSION_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_REGRESSION_METRIC_HPP_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2017 Microsoft Corporation. All rights reserved. * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_METRIC_XENTROPY_METRIC_HPP_ #ifndef LIGHTGBM_SRC_METRIC_XENTROPY_METRIC_HPP_
#define LIGHTGBM_METRIC_XENTROPY_METRIC_HPP_ #define LIGHTGBM_SRC_METRIC_XENTROPY_METRIC_HPP_
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
...@@ -30,40 +30,40 @@ ...@@ -30,40 +30,40 @@
namespace LightGBM { namespace LightGBM {
// label should be in interval [0, 1]; // label should be in interval [0, 1];
// prob should be in interval (0, 1); prob is clipped if needed // prob should be in interval (0, 1); prob is clipped if needed
inline static double XentLoss(label_t label, double prob) { inline static double XentLoss(label_t label, double prob) {
const double log_arg_epsilon = 1.0e-12; const double log_arg_epsilon = 1.0e-12;
double a = label; double a = label;
if (prob > log_arg_epsilon) { if (prob > log_arg_epsilon) {
a *= std::log(prob); a *= std::log(prob);
} else { } else {
a *= std::log(log_arg_epsilon); a *= std::log(log_arg_epsilon);
}
double b = 1.0f - label;
if (1.0f - prob > log_arg_epsilon) {
b *= std::log(1.0f - prob);
} else {
b *= std::log(log_arg_epsilon);
}
return - (a + b);
} }
double b = 1.0f - label;
// hhat >(=) 0 assumed; and weight > 0 required; but not checked here if (1.0f - prob > log_arg_epsilon) {
inline static double XentLambdaLoss(label_t label, label_t weight, double hhat) { b *= std::log(1.0f - prob);
return XentLoss(label, 1.0f - std::exp(-weight * hhat)); } else {
} b *= std::log(log_arg_epsilon);
// Computes the (negative) entropy for label p; p should be in interval [0, 1];
// This is used to presum the KL-divergence offset term (to be _added_ to the cross-entropy loss).
// NOTE: x*log(x) = 0 for x=0,1; so only add when in (0, 1); avoid log(0)*0
inline static double YentLoss(double p) {
double hp = 0.0;
if (p > 0) hp += p * std::log(p);
double q = 1.0f - p;
if (q > 0) hp += q * std::log(q);
return hp;
} }
return - (a + b);
}
// hhat >(=) 0 assumed; and weight > 0 required; but not checked here
inline static double XentLambdaLoss(label_t label, label_t weight, double hhat) {
return XentLoss(label, 1.0f - std::exp(-weight * hhat));
}
// Computes the (negative) entropy for label p; p should be in interval [0, 1];
// This is used to presum the KL-divergence offset term (to be _added_ to the cross-entropy loss).
// NOTE: x*log(x) = 0 for x=0,1; so only add when in (0, 1); avoid log(0)*0
inline static double YentLoss(double p) {
double hp = 0.0;
if (p > 0) hp += p * std::log(p);
double q = 1.0f - p;
if (q > 0) hp += q * std::log(q);
return hp;
}
// //
// CrossEntropyMetric : "xentropy" : (optional) weights are used linearly // CrossEntropyMetric : "xentropy" : (optional) weights are used linearly
...@@ -355,4 +355,4 @@ class KullbackLeiblerDivergence : public Metric { ...@@ -355,4 +355,4 @@ class KullbackLeiblerDivergence : public Metric {
} // end namespace LightGBM } // end namespace LightGBM
#endif // end #ifndef LIGHTGBM_METRIC_XENTROPY_METRIC_HPP_ #endif // LIGHTGBM_SRC_METRIC_XENTROPY_METRIC_HPP_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2016 Microsoft Corporation. All rights reserved. * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
*/ */
#ifndef LIGHTGBM_NETWORK_LINKERS_H_ #ifndef LIGHTGBM_SRC_NETWORK_LINKERS_H_
#define LIGHTGBM_NETWORK_LINKERS_H_ #define LIGHTGBM_SRC_NETWORK_LINKERS_H_
#include <LightGBM/config.h> #include <LightGBM/config.h>
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
...@@ -325,4 +325,4 @@ inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len, ...@@ -325,4 +325,4 @@ inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len,
#endif // USE_MPI #endif // USE_MPI
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_NETWORK_LINKERS_H_ #endif // LIGHTGBM_SRC_NETWORK_LINKERS_H_
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include "linkers.h" #include "linkers.h"
#include <iostream>
namespace LightGBM { namespace LightGBM {
Linkers::Linkers(Config) { Linkers::Linkers(Config) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment