Unverified commit 716fe4d0, authored by Nikita Titov and committed by GitHub
Browse files

fixed cpplint errors about spaces and indents (#2282)

parent ee28ea36
...@@ -447,7 +447,7 @@ struct Config { ...@@ -447,7 +447,7 @@ struct Config {
// default = None // default = None
// desc = max number of bins for each feature // desc = max number of bins for each feature
// desc = if not specified, will use ``max_bin`` for all features // desc = if not specified, will use ``max_bin`` for all features
std::vector<int32_t> max_bin_by_feature; std::vector<int32_t> max_bin_by_feature;
// check = >0 // check = >0
// desc = minimal number of data inside one bin // desc = minimal number of data inside one bin
......
...@@ -405,7 +405,7 @@ class Tree { ...@@ -405,7 +405,7 @@ class Tree {
}; };
inline void Tree::Split(int leaf, int feature, int real_feature, inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain) { double left_weight, double right_weight, float gain) {
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
// update parent info // update parent info
......
...@@ -136,7 +136,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -136,7 +136,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
valid_metrics_.back().push_back(metric); valid_metrics_.back().push_back(metric);
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
auto num_metrics = valid_metrics.size(); auto num_metrics = valid_metrics.size();
if (es_first_metric_only_) { num_metrics = 1; } if (es_first_metric_only_) { num_metrics = 1; }
...@@ -739,7 +739,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { ...@@ -739,7 +739,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
} }
if (balance_bagging_cond) { if (balance_bagging_cond) {
balanced_bagging_ = true; balanced_bagging_ = true;
bag_data_cnt_ = static_cast<data_size_t>(num_pos_data * config->pos_bagging_fraction) bag_data_cnt_ = static_cast<data_size_t>(num_pos_data * config->pos_bagging_fraction)
+ static_cast<data_size_t>((num_data_ - num_pos_data) * config->neg_bagging_fraction); + static_cast<data_size_t>((num_data_ - num_pos_data) * config->neg_bagging_fraction);
} else { } else {
bag_data_cnt_ = static_cast<data_size_t>(config->bagging_fraction * num_data_); bag_data_cnt_ = static_cast<data_size_t>(config->bagging_fraction * num_data_);
......
...@@ -711,11 +711,11 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -711,11 +711,11 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
} }
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
int num_rows, int num_rows,
int64_t num_col, int64_t num_col,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr); auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
...@@ -767,10 +767,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ...@@ -767,10 +767,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
for (int i = 0; i < num_rows; ++i) { for (int i = 0; i < num_rows; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
{ {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
get_row_fun(i, threadBuffer); get_row_fun(i, threadBuffer);
ret->PushOneRow(tid, i, threadBuffer);
ret->PushOneRow(tid, i, threadBuffer);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
...@@ -1291,19 +1290,19 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1291,19 +1290,19 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
} }
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t, int64_t,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
auto param = Config::Str2Map(parameter); auto param = Config::Str2Map(parameter);
Config config; Config config;
...@@ -1313,8 +1312,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1313,8 +1312,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len);
config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1397,15 +1395,15 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1397,15 +1395,15 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
} }
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
const void* data, const void* data,
int data_type, int data_type,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
auto param = Config::Str2Map(parameter); auto param = Config::Str2Map(parameter);
Config config; Config config;
...@@ -1415,8 +1413,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1415,8 +1413,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len);
config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1440,8 +1437,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -1440,8 +1437,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, config, out_result, out_len);
config, out_result, out_len);
API_END(); API_END();
} }
......
...@@ -172,7 +172,7 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -172,7 +172,7 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
GetTreeLearnerType(params, &tree_learner); GetTreeLearnerType(params, &tree_learner);
GetMembersFromString(params); GetMembersFromString(params);
// sort eval_at // sort eval_at
std::sort(eval_at.begin(), eval_at.end()); std::sort(eval_at.begin(), eval_at.end());
......
...@@ -584,13 +584,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -584,13 +584,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) { if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.max_bin, config_.min_data_in_bin, filter_cnt,
bin_type, config_.use_missing, config_.zero_as_missing); bin_type, config_.use_missing, config_.zero_as_missing);
} else { } else {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin_by_feature[i], config_.min_data_in_bin, config_.max_bin_by_feature[i], config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing); config_.zero_as_missing);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
...@@ -628,13 +628,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -628,13 +628,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) { if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin, config_.min_data_in_bin, total_sample_size, config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else { } else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin_by_feature[start[rank] + i], total_sample_size, config_.max_bin_by_feature[start[rank] + i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing); config_.zero_as_missing);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
...@@ -908,12 +908,12 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -908,12 +908,12 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) { if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()), bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else { } else {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()), bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin_by_feature[i], sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing); config_.zero_as_missing);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
...@@ -952,16 +952,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -952,16 +952,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) { if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()), static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else { } else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()), static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin_by_feature[i], sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.min_data_in_bin, filter_cnt, bin_type,
config_.use_missing, config_.zero_as_missing); config_.use_missing, config_.zero_as_missing);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
......
...@@ -515,7 +515,7 @@ Tree::Tree(const char* str, size_t* used_len) { ...@@ -515,7 +515,7 @@ Tree::Tree(const char* str, size_t* used_len) {
} else { } else {
Log::Fatal("Tree model string format error, should contain leaf_value field"); Log::Fatal("Tree model string format error, should contain leaf_value field");
} }
if (key_vals.count("shrinkage")) { if (key_vals.count("shrinkage")) {
Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_); Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
} else { } else {
...@@ -568,15 +568,13 @@ Tree::Tree(const char* str, size_t* used_len) { ...@@ -568,15 +568,13 @@ Tree::Tree(const char* str, size_t* used_len) {
if (key_vals.count("internal_weight")) { if (key_vals.count("internal_weight")) {
internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1); internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
} } else {
else {
internal_weight_.resize(num_leaves_ - 1); internal_weight_.resize(num_leaves_ - 1);
} }
if (key_vals.count("leaf_weight")) { if (key_vals.count("leaf_weight")) {
leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_); leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
} } else {
else {
leaf_weight_.resize(num_leaves_); leaf_weight_.resize(num_leaves_);
} }
......
...@@ -20,7 +20,7 @@ namespace LightGBM { ...@@ -20,7 +20,7 @@ namespace LightGBM {
template<typename PointWiseLossCalculator> template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric { class MulticlassMetric: public Metric {
public: public:
explicit MulticlassMetric(const Config& config) :config_(config){ explicit MulticlassMetric(const Config& config) :config_(config) {
num_class_ = config.num_class; num_class_ = config.num_class;
} }
...@@ -149,8 +149,11 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> { ...@@ -149,8 +149,11 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
} }
inline static const std::string Name(const Config& config) { inline static const std::string Name(const Config& config) {
if (config.multi_error_top_k == 1) return "multi_error"; if (config.multi_error_top_k == 1) {
else return "multi_error@" + std::to_string(config.multi_error_top_k); return "multi_error";
} else {
return "multi_error@" + std::to_string(config.multi_error_top_k);
}
} }
}; };
......
...@@ -239,23 +239,23 @@ class RegressionL1loss: public RegressionL2loss { ...@@ -239,23 +239,23 @@ class RegressionL1loss: public RegressionL2loss {
const double alpha = 0.5; const double alpha = 0.5;
if (weights_ == nullptr) { if (weights_ == nullptr) {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (residual_getter(label_,index_mapper[i])) #define data_reader(i) (residual_getter(label_, index_mapper[i]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha); PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
} else { } else {
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]])) #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha); PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
} }
} else { } else {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (residual_getter(label_,index_mapper[i])) #define data_reader(i) (residual_getter(label_, index_mapper[i]))
#define weight_reader(i) (weights_[index_mapper[i]]) #define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]])) #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
...@@ -526,23 +526,23 @@ class RegressionQuantileloss : public RegressionL2loss { ...@@ -526,23 +526,23 @@ class RegressionQuantileloss : public RegressionL2loss {
data_size_t num_data_in_leaf) const override { data_size_t num_data_in_leaf) const override {
if (weights_ == nullptr) { if (weights_ == nullptr) {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (residual_getter(label_,index_mapper[i])) #define data_reader(i) (residual_getter(label_, index_mapper[i]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha_); PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
} else { } else {
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]])) #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha_); PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
} }
} else { } else {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (residual_getter(label_,index_mapper[i])) #define data_reader(i) (residual_getter(label_, index_mapper[i]))
#define weight_reader(i) (weights_[index_mapper[i]]) #define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]])) #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
...@@ -627,13 +627,13 @@ class RegressionMAPELOSS : public RegressionL1loss { ...@@ -627,13 +627,13 @@ class RegressionMAPELOSS : public RegressionL1loss {
data_size_t num_data_in_leaf) const override { data_size_t num_data_in_leaf) const override {
const double alpha = 0.5; const double alpha = 0.5;
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (residual_getter(label_,index_mapper[i])) #define data_reader(i) (residual_getter(label_, index_mapper[i]))
#define weight_reader(i) (label_weight_[index_mapper[i]]) #define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]])) #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
......
/*!
* Copyright (c) 2018 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
/* lightgbmlib.i */ /* lightgbmlib.i */
%module lightgbmlib %module lightgbmlib
%ignore LGBM_BoosterSaveModelToString; %ignore LGBM_BoosterSaveModelToString;
...@@ -23,8 +27,8 @@ ...@@ -23,8 +27,8 @@
%include "../include/LightGBM/export.h" %include "../include/LightGBM/export.h"
%include "../include/LightGBM/c_api.h" %include "../include/LightGBM/c_api.h"
%typemap(in, numinputs=0) JNIEnv *jenv %{ %typemap(in, numinputs = 0) JNIEnv *jenv %{
$1 = jenv; $1 = jenv;
%} %}
%inline %{ %inline %{
...@@ -59,9 +63,9 @@ ...@@ -59,9 +63,9 @@
return nullptr; return nullptr;
} }
return dst; return dst;
} }
int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv, int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
jdoubleArray data, jdoubleArray data,
BoosterHandle handle, BoosterHandle handle,
int data_type, int data_type,
...@@ -73,53 +77,53 @@ ...@@ -73,53 +77,53 @@
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0); double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type, int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type,
num_iteration, parameter, out_len, out_result); num_iteration, parameter, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
return ret; return ret;
} }
int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv, int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
jintArray indices, jintArray indices,
jdoubleArray values, jdoubleArray values,
int numNonZeros, int numNonZeros,
BoosterHandle handle, BoosterHandle handle,
int indptr_type, int indptr_type,
int data_type, int data_type,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
// Alternatives // Alternatives
// - GetIntArrayElements: performs copy // - GetIntArrayElements: performs copy
// - GetDirectBufferAddress: fails on wrapped array // - GetDirectBufferAddress: fails on wrapped array
// Some words of warning for GetPrimitiveArrayCritical // Some words of warning for GetPrimitiveArrayCritical
// https://stackoverflow.com/questions/23258357/whats-the-trade-off-between-using-getprimitivearraycritical-and-getprimitivety // https://stackoverflow.com/questions/23258357/whats-the-trade-off-between-using-getprimitivearraycritical-and-getprimitivety
jboolean isCopy; jboolean isCopy;
int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, &isCopy); int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, &isCopy);
double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, &isCopy); double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, &isCopy);
int32_t ind[2] = { 0, numNonZeros }; int32_t ind[2] = { 0, numNonZeros };
int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2, int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result); nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
return ret; return ret;
} }
#include <vector>
#include <functional> #include <functional>
#include <vector>
struct CSRDirect { struct CSRDirect {
jintArray indices; jintArray indices;
jdoubleArray values; jdoubleArray values;
...@@ -127,7 +131,7 @@ ...@@ -127,7 +131,7 @@
double* values0; double* values0;
int size; int size;
}; };
int LGBM_DatasetCreateFromCSRSpark(JNIEnv *jenv, int LGBM_DatasetCreateFromCSRSpark(JNIEnv *jenv,
jobjectArray arrayOfSparseVector, jobjectArray arrayOfSparseVector,
int num_rows, int num_rows,
...@@ -135,25 +139,25 @@ ...@@ -135,25 +139,25 @@
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
jclass sparseVectorClass = jenv->FindClass("org/apache/spark/ml/linalg/SparseVector"); jclass sparseVectorClass = jenv->FindClass("org/apache/spark/ml/linalg/SparseVector");
jmethodID sparseVectorIndices = jenv->GetMethodID(sparseVectorClass, "indices", "()[I"); jmethodID sparseVectorIndices = jenv->GetMethodID(sparseVectorClass, "indices", "()[I");
jmethodID sparseVectorValues = jenv->GetMethodID(sparseVectorClass, "values", "()[D"); jmethodID sparseVectorValues = jenv->GetMethodID(sparseVectorClass, "values", "()[D");
std::vector<CSRDirect> jniCache; std::vector<CSRDirect> jniCache;
jniCache.reserve(num_rows); jniCache.reserve(num_rows);
// this needs to be done ahead of time as row_func is invoked from multiple threads // this needs to be done ahead of time as row_func is invoked from multiple threads
// these threads would have to be registered with the JVM and also unregistered. // these threads would have to be registered with the JVM and also unregistered.
// It is not clear if that can be achieved with OpenMP // It is not clear if that can be achieved with OpenMP
for (int i=0; i<num_rows; i++) { for (int i = 0; i < num_rows; i++) {
// get the row // get the row
jobject objSparseVec = jenv->GetObjectArrayElement(arrayOfSparseVector, i); jobject objSparseVec = jenv->GetObjectArrayElement(arrayOfSparseVector, i);
// get the size, indices and values // get the size, indices and values
auto indices = (jintArray)jenv->CallObjectMethod(objSparseVec, sparseVectorIndices); auto indices = (jintArray)jenv->CallObjectMethod(objSparseVec, sparseVectorIndices);
auto values = (jdoubleArray)jenv->CallObjectMethod(objSparseVec, sparseVectorValues); auto values = (jdoubleArray)jenv->CallObjectMethod(objSparseVec, sparseVectorValues);
int size = jenv->GetArrayLength(indices); int size = jenv->GetArrayLength(indices);
// Note: when testing on larger data (e.g. 288k rows per partition and 36mio rows total) // Note: when testing on larger data (e.g. 288k rows per partition and 36mio rows total)
// using GetPrimitiveArrayCritical resulted in a dead-lock // using GetPrimitiveArrayCritical resulted in a dead-lock
// lock arrays // lock arrays
...@@ -162,35 +166,35 @@ ...@@ -162,35 +166,35 @@
// in test-usecase an alternative to GetPrimitiveArrayCritical as it performs copies // in test-usecase an alternative to GetPrimitiveArrayCritical as it performs copies
int* indices0 = (int *)jenv->GetIntArrayElements(indices, 0); int* indices0 = (int *)jenv->GetIntArrayElements(indices, 0);
double* values0 = jenv->GetDoubleArrayElements(values, 0); double* values0 = jenv->GetDoubleArrayElements(values, 0);
jniCache.push_back({indices, values, indices0, values0, size}); jniCache.push_back({indices, values, indices0, values0, size});
} }
// type is important here as we want a std::function, rather than a lambda // type is important here as we want a std::function, rather than a lambda
std::function<void(int idx, std::vector<std::pair<int, double>>& ret)> row_func = [&](int row_num, std::vector<std::pair<int, double>>& ret) { std::function<void(int idx, std::vector<std::pair<int, double>>& ret)> row_func = [&](int row_num, std::vector<std::pair<int, double>>& ret) {
auto& jc = jniCache[row_num]; auto& jc = jniCache[row_num];
ret.clear(); // reset size, but not free() ret.clear(); // reset size, but not free()
ret.reserve(jc.size); // make sure we have enough allocated ret.reserve(jc.size); // make sure we have enough allocated
// copy data // copy data
int* indices0p = jc.indices0; int* indices0p = jc.indices0;
double* values0p = jc.values0; double* values0p = jc.values0;
int* indices0e = indices0p + jc.size; int* indices0e = indices0p + jc.size;
for (; indices0p != indices0e; ++indices0p, ++values0p) for (; indices0p != indices0e; ++indices0p, ++values0p)
ret.emplace_back(*indices0p, *values0p); ret.emplace_back(*indices0p, *values0p);
}; };
int ret = LGBM_DatasetCreateFromCSRFunc(&row_func, num_rows, num_col, parameters, reference, out); int ret = LGBM_DatasetCreateFromCSRFunc(&row_func, num_rows, num_col, parameters, reference, out);
for (auto& jc : jniCache) { for (auto& jc : jniCache) {
// jenv->ReleasePrimitiveArrayCritical(jc.values, jc.values0, JNI_ABORT); // jenv->ReleasePrimitiveArrayCritical(jc.values, jc.values0, JNI_ABORT);
// jenv->ReleasePrimitiveArrayCritical(jc.indices, jc.indices0, JNI_ABORT); // jenv->ReleasePrimitiveArrayCritical(jc.indices, jc.indices0, JNI_ABORT);
jenv->ReleaseDoubleArrayElements(jc.values, jc.values0, JNI_ABORT); jenv->ReleaseDoubleArrayElements(jc.values, jc.values0, JNI_ABORT);
jenv->ReleaseIntArrayElements(jc.indices, (jint *)jc.indices0, JNI_ABORT); jenv->ReleaseIntArrayElements(jc.indices, (jint *)jc.indices0, JNI_ABORT);
} }
return ret; return ret;
} }
%} %}
...@@ -224,7 +228,7 @@ ...@@ -224,7 +228,7 @@
%array_functions(char *, stringArray) %array_functions(char *, stringArray)
/* Custom pointer manipulation template */ /* Custom pointer manipulation template */
%define %pointer_manipulation(TYPE,NAME) %define %pointer_manipulation(TYPE, NAME)
%{ %{
static TYPE *new_##NAME() { %} static TYPE *new_##NAME() { %}
%{ TYPE* NAME = new TYPE; return NAME; %} %{ TYPE* NAME = new TYPE; return NAME; %}
...@@ -240,7 +244,7 @@ void delete_##NAME(TYPE *self); ...@@ -240,7 +244,7 @@ void delete_##NAME(TYPE *self);
%enddef %enddef
%define %pointer_dereference(TYPE,NAME) %define %pointer_dereference(TYPE, NAME)
%{ %{
static TYPE NAME ##_value(TYPE *self) { static TYPE NAME ##_value(TYPE *self) {
TYPE NAME = *self; TYPE NAME = *self;
...@@ -252,7 +256,7 @@ TYPE NAME##_value(TYPE *self); ...@@ -252,7 +256,7 @@ TYPE NAME##_value(TYPE *self);
%enddef %enddef
%define %pointer_handle(TYPE,NAME) %define %pointer_handle(TYPE, NAME)
%{ %{
static TYPE* NAME ##_handle() { %} static TYPE* NAME ##_handle() { %}
%{ TYPE* NAME = new TYPE; *NAME = (TYPE)operator new(sizeof(int*)); return NAME; %} %{ TYPE* NAME = new TYPE; *NAME = (TYPE)operator new(sizeof(int*)); return NAME; %}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment