c_api.cpp 71.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Store the message of a caught exception as the last error.
 * \param ex Exception whose what() text is handed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
/*!
 * \brief Store an error message given as a string as the last error.
 * \param ex Message text handed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN()/API_END() bracket the body of a C API function: the body runs
// inside a try block; a thrown std::exception, std::string or any other value
// is routed to LGBM_APIHandleException (which returns -1), and 0 is returned
// when nothing was thrown.
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
// Length of Booster's array of cached SingleRowPredictor objects, which is
// indexed by predict_type.
const int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
57
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
74
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
75
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
76
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
77
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
78
    num_total_model_ = boosting->NumberOfTotalModel();
79
80
  }
  ~SingleRowPredictor() {}
Guolin Ke's avatar
Guolin Ke committed
81
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
82
83
84
85
    return early_stop_ != config.pred_early_stop ||
      early_stop_freq_ != config.pred_early_stop_freq ||
      early_stop_margin_ != config.pred_early_stop_margin ||
      iter_ != iter ||
Guolin Ke's avatar
Guolin Ke committed
86
      num_total_model_ != boosting->NumberOfTotalModel();
87
  }
Guolin Ke's avatar
Guolin Ke committed
88

89
90
91
92
93
94
95
96
97
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
98
class Booster {
Nikita Titov's avatar
Nikita Titov committed
99
 public:
Guolin Ke's avatar
Guolin Ke committed
100
  explicit Booster(const char* filename) {
101
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
102
103
  }

Guolin Ke's avatar
Guolin Ke committed
104
  Booster(const Dataset* train_data,
105
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
106
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
107
    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
112
    if (config_.input_model.size() > 0) {
113
114
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
115
    }
Guolin Ke's avatar
Guolin Ke committed
116

Guolin Ke's avatar
Guolin Ke committed
117
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
118

119
120
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
121
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
122
    if (config_.tree_learner == std::string("feature")) {
123
      Log::Fatal("Do not support feature parallel in c api");
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
126
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
127
      config_.tree_learner = "serial";
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
130
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
131
132
133
134
135
  }

  /*!
   * \brief Forward the other booster's Boosting object to Boosting::MergeFrom.
   * \param other Booster to merge from
   * \note Only this booster's mutex is held; `other` is read without taking its mutex.
   */
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* source = other->boosting_.get();
    boosting_->MergeFrom(source);
  }

  /*! \brief Nothing to release by hand; the members clean up after themselves */
  ~Booster() = default;
140

141
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
142
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
152
153
154
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
155
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
156
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
157
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
162
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
163
164
165
166
167
168
169
170
171
172
173
  }

  /*!
   * \brief Switch to another training dataset; does nothing when the pointer is unchanged.
   * \param train_data New training dataset; only the pointer is stored
   */
  void ResetTrainingData(const Dataset* train_data) {
    // NOTE(review): this comparison reads train_data_ before mutex_ is taken
    if (train_data == train_data_) {
      return;
    }
    std::lock_guard<std::mutex> guard(mutex_);
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // hand the new data, objective and metrics to the boosting
    boosting_->ResetTrainingData(train_data_, objective_fun_.get(),
                                 Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
  /*!
   * \brief Reject parameter updates that would change settings fixed at Dataset construction.
   *
   * A setting is only checked when its key is present in new_param. Each
   * violation is reported through Log::Fatal.
   *
   * \param old_config Configuration currently in effect
   * \param new_param Parameters the caller wants to apply
   */
  static void CheckDatasetResetConfig(
      const Config& old_config,
      const std::unordered_map<std::string, std::string>& new_param) {
    Config new_config;
    new_config.Set(new_param);

    // true when the caller explicitly supplied `key`
    const auto given = [&new_param](const char* key) {
      return new_param.count(key) > 0;
    };

    if (given("data_random_seed") &&
        old_config.data_random_seed != new_config.data_random_seed) {
      Log::Fatal("Cannot change data_random_seed after constructed Dataset handle.");
    }
    if (given("max_bin") && old_config.max_bin != new_config.max_bin) {
      Log::Fatal("Cannot change max_bin after constructed Dataset handle.");
    }
    if (given("max_bin_by_feature") &&
        old_config.max_bin_by_feature != new_config.max_bin_by_feature) {
      Log::Fatal("Cannot change max_bin_by_feature after constructed Dataset handle.");
    }
    if (given("bin_construct_sample_cnt") &&
        old_config.bin_construct_sample_cnt != new_config.bin_construct_sample_cnt) {
      Log::Fatal("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
    }
    if (given("min_data_in_bin") &&
        old_config.min_data_in_bin != new_config.min_data_in_bin) {
      Log::Fatal("Cannot change min_data_in_bin after constructed Dataset handle.");
    }
    if (given("use_missing") && old_config.use_missing != new_config.use_missing) {
      Log::Fatal("Cannot change use_missing after constructed Dataset handle.");
    }
    if (given("zero_as_missing") &&
        old_config.zero_as_missing != new_config.zero_as_missing) {
      Log::Fatal("Cannot change zero_as_missing after constructed Dataset handle.");
    }
    if (given("categorical_feature") &&
        old_config.categorical_feature != new_config.categorical_feature) {
      Log::Fatal("Cannot change categorical_feature after constructed Dataset handle.");
    }
    if (given("feature_pre_filter") &&
        old_config.feature_pre_filter != new_config.feature_pre_filter) {
      Log::Fatal("Cannot change feature_pre_filter after constructed Dataset handle.");
    }
    if (given("is_enable_sparse") &&
        old_config.is_enable_sparse != new_config.is_enable_sparse) {
      Log::Fatal("Cannot change is_enable_sparse after constructed Dataset handle.");
    }
    if (given("pre_partition") &&
        old_config.pre_partition != new_config.pre_partition) {
      Log::Fatal("Cannot change pre_partition after constructed Dataset handle.");
    }
    if (given("enable_bundle") &&
        old_config.enable_bundle != new_config.enable_bundle) {
      Log::Fatal("Cannot change enable_bundle after constructed Dataset handle.");
    }
    if (given("header") && old_config.header != new_config.header) {
      Log::Fatal("Cannot change header after constructed Dataset handle.");
    }
    if (given("two_round") && old_config.two_round != new_config.two_round) {
      Log::Fatal("Cannot change two_round after constructed Dataset handle.");
    }
    if (given("label_column") &&
        old_config.label_column != new_config.label_column) {
      Log::Fatal("Cannot change label_column after constructed Dataset handle.");
    }
    if (given("weight_column") &&
        old_config.weight_column != new_config.weight_column) {
      Log::Fatal("Cannot change weight_column after constructed Dataset handle.");
    }
    if (given("group_column") &&
        old_config.group_column != new_config.group_column) {
      Log::Fatal("Cannot change group_column after constructed Dataset handle.");
    }
    if (given("ignore_column") &&
        old_config.ignore_column != new_config.ignore_column) {
      Log::Fatal("Cannot change ignore_column after constructed Dataset handle.");
    }
    // forced bins are rejected whenever the key is present, without comparing values
    if (given("forcedbins_filename")) {
      Log::Fatal("Cannot change forced bins after constructed Dataset handle.");
    }
    // lowering min_data_in_leaf is only rejected while feature_pre_filter is on
    if (given("min_data_in_leaf") &&
        new_config.min_data_in_leaf < old_config.min_data_in_leaf &&
        old_config.feature_pre_filter) {
      Log::Fatal(
          "Reducing `min_data_in_leaf` with `feature_pre_filter=true` may "
          "cause unexpected behaviour "
          "for features that were pre-filtered by the larger "
          "`min_data_in_leaf`.\n"
          "You need to set `feature_pre_filter=false` to dynamically change "
          "the `min_data_in_leaf`.");
    }
  }

wxchan's avatar
wxchan committed
284
  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
285
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
286
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
287
    if (param.count("num_class")) {
288
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
289
    }
Guolin Ke's avatar
Guolin Ke committed
290
291
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
292
    }
Guolin Ke's avatar
Guolin Ke committed
293
    if (param.count("metric")) {
294
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
295
    }
Guolin Ke's avatar
Guolin Ke committed
296

297
298
    CheckDatasetResetConfig(config_, param);

Guolin Ke's avatar
Guolin Ke committed
299
    config_.Set(param);
300

301
302
303
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
304
305
306

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
307
308
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
309
310
311
312
313
314
315
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
316
317
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
318
    }
Guolin Ke's avatar
Guolin Ke committed
319

Guolin Ke's avatar
Guolin Ke committed
320
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
321
322
323
324
325
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
326
327
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
328
329
330
331
332
333
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
334
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
335
  }
Guolin Ke's avatar
Guolin Ke committed
336

337
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
338
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
339
    return boosting_->TrainOneIter(nullptr, nullptr);
340
341
  }

Guolin Ke's avatar
Guolin Ke committed
342
343
344
345
346
347
348
349
350
351
352
  /*!
   * \brief Copy a flat leaf-prediction matrix into nested vectors and pass it to Boosting::RefitTree.
   * \param leaf_preds Flat array read as leaf_preds[i * ncol + j] for row i, column j
   * \param nrow Number of rows
   * \param ncol Number of columns
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    // Offsets are computed in size_t: the product i * ncol can exceed the
    // range of a 32-bit int for large matrices.
    const size_t row_stride = static_cast<size_t>(ncol);
    for (int32_t i = 0; i < nrow; ++i) {
      const int32_t* src_row = leaf_preds + static_cast<size_t>(i) * row_stride;
      for (int32_t j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = src_row[j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

353
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
354
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
355
    return boosting_->TrainOneIter(gradients, hessians);
356
357
  }

wxchan's avatar
wxchan committed
358
359
360
361
362
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

363
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
364
365
366
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
367
368
369
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
370
    }
371
    std::lock_guard<std::mutex> lock(mutex_);
372
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
373
374
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
375
                                                                       config, num_iteration));
376
377
378
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
379
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
380

381
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
382
383
384
  }


385
  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
Guolin Ke's avatar
Guolin Ke committed
386
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
387
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
388
               double* out_result, int64_t* out_len) {
389
390
391
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
392
    }
wxchan's avatar
wxchan committed
393
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
394
395
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
396
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
397
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
398
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
399
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
400
      is_raw_score = true;
401
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
402
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
403
404
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
405
    }
Guolin Ke's avatar
Guolin Ke committed
406

Guolin Ke's avatar
Guolin Ke committed
407
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
408
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
409
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
410
    auto pred_fun = predictor.GetPredictFunction();
411
412
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
413
    for (int i = 0; i < nrow; ++i) {
414
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
415
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
416
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
417
      pred_fun(one_row, pred_wrt_ptr);
418
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
419
    }
420
    OMP_THROW_EX();
421
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
422
423
424
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
425
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
426
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
427
428
429
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
430
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
431
432
433
434
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
435
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
436
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
437
438
439
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
440
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
441
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
442
    bool bool_data_has_header = data_has_header > 0 ? true : false;
443
    predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
Guolin Ke's avatar
Guolin Ke committed
444
445
  }

Guolin Ke's avatar
Guolin Ke committed
446
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
447
448
449
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

450
451
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
452
  }
453

454
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen
   */
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str, std::strlen(model_str));
  }

459
460
  /*!
   * \brief Forward to Boosting::SaveModelToString.
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \return The string produced by the boosting object
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    std::string model_text = boosting_->SaveModelToString(start_iteration, num_iteration);
    return model_text;
  }

463
  std::string DumpModel(int start_iteration, int num_iteration) {
464
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
465
  }
466

467
468
469
470
  /*!
   * \brief Forward to Boosting::FeatureImportance.
   * \param num_iteration Forwarded unchanged
   * \param importance_type Forwarded unchanged
   * \return The vector produced by the boosting object
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importance = boosting_->FeatureImportance(num_iteration, importance_type);
    return importance;
  }

Guolin Ke's avatar
Guolin Ke committed
471
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
472
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
473
474
475
476
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
477
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
478
479
  }

480
  void ShuffleModels(int start_iter, int end_iter) {
481
    std::lock_guard<std::mutex> lock(mutex_);
482
    boosting_->ShuffleModels(start_iter, end_iter);
483
484
  }

wxchan's avatar
wxchan committed
485
486
487
488
489
490
491
  /*!
   * \brief Count the metric names reported by all training metrics.
   * \return Sum of GetName().size() over train_metric_
   */
  int GetEvalCounts() const {
    int total_names = 0;
    for (const auto& metric : train_metric_) {
      const auto names_in_metric = metric->GetName().size();
      total_names += static_cast<int>(names_in_metric);
    }
    return total_names;
  }
492

wxchan's avatar
wxchan committed
493
494
495
496
  /*!
   * \brief Copy every training-metric name, including its terminating NUL, into caller buffers.
   * \param out_strs Array of destination buffers, one per name; buffer sizes are not checked here
   * \return Number of names written
   */
  int GetEvalNames(char** out_strs) const {
    int written = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        char* dest = out_strs[written];
        // size() + 1 so the terminating NUL is copied too
        std::memcpy(dest, name.c_str(), name.size() + 1);
        ++written;
      }
    }
    return written;
  }

wxchan's avatar
wxchan committed
504
505
506
  /*!
   * \brief Copy every feature name, including its terminating NUL, into caller buffers.
   * \param out_strs Array of destination buffers, one per name; buffer sizes are not checked here
   * \return Number of names written
   */
  int GetFeatureNames(char** out_strs) const {
    int written = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      char* dest = out_strs[written];
      // size() + 1 so the terminating NUL is copied too
      std::memcpy(dest, name.c_str(), name.size() + 1);
      ++written;
    }
    return written;
  }

wxchan's avatar
wxchan committed
513
  /*! \brief Read-only access to the underlying Boosting object; ownership stays with the booster */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
514

Nikita Titov's avatar
Nikita Titov committed
515
 private:
wxchan's avatar
wxchan committed
516
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
517
  std::unique_ptr<Boosting> boosting_;
518
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
519

Guolin Ke's avatar
Guolin Ke committed
520
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
521
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
522
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
523
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
524
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
525
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
526
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
527
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
528
529
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
530
531
};

532
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
533
534
535

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
536
537
538
539
540
541
542
543
// some help functions used to convert data

// Forward declarations of the row-accessor factories used by the C API
// functions below. Each returns a callable that maps a row index to that
// row's values.

// Dense matrix -> row as a plain vector of doubles.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Dense matrix -> row as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// Array of dense row pointers -> row as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data,
                             int num_col,
                             int data_type);

// CSR matrix -> row as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
550
551
552

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  /*!
   * \brief Declared here; the definition is outside this class body.
   *        The arguments describe the CSC arrays and select the column (col_idx).
   */
  CSC_RowIterator(const void* col_ptr,
                  int col_ptr_type,
                  const int32_t* indices,
                  const void* data,
                  int data_type,
                  int64_t ncol_ptr,
                  int64_t nelem,
                  int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value at row idx; rows may only be requested in ascending order */
  double Get(int idx);

  /*! \brief Next non-zero (index, value) pair; a negative index means there is no more data */
  std::pair<int, double> NextNonZero();

 private:
  /*! \brief Iteration state, starting before the first entry */
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  /*! \brief Callable that maps an int index to an (int, double) pair */
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
572
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
573
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
574
575
}

Guolin Ke's avatar
Guolin Ke committed
576
int LGBM_DatasetCreateFromFile(const char* filename,
577
578
579
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
580
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
581
582
  auto param = Config::Str2Map(parameters);
  Config config;
583
584
585
586
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
587
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
588
  if (reference == nullptr) {
589
590
591
592
593
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
594
  } else {
595
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
596
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
597
  }
598
  API_END();
Guolin Ke's avatar
Guolin Ke committed
599
600
}

601

Guolin Ke's avatar
Guolin Ke committed
602
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
603
604
605
606
607
608
609
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
610
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
611
612
  auto param = Config::Str2Map(parameters);
  Config config;
613
614
615
616
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
617
  DatasetLoader loader(config, nullptr, 1, nullptr);
618
619
620
621
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
622
623
}

624

Guolin Ke's avatar
Guolin Ke committed
625
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
626
627
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
628
629
630
631
632
633
634
635
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
636
int LGBM_DatasetPushRows(DatasetHandle dataset,
637
638
639
640
641
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
642
643
644
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
645
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
646
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
647
  for (int i = 0; i < nrow; ++i) {
648
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
649
650
651
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
652
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
653
  }
654
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
655
656
657
658
659
660
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
661
/*!
 * \brief Push CSR-encoded rows into an existing dataset, starting at row `start_row`.
 *
 * The number of rows is derived as `nindptr - 1`. Rows are read through
 * RowFunctionFromCSR and inserted in parallel with Dataset::PushOneRow. When the
 * pushed range ends exactly at the dataset's num_data(), Dataset::FinishLoad is called.
 * The unnamed int64_t parameter is not used by this function.
 *
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto* target = reinterpret_cast<Dataset*>(dataset);
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t batch_rows = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int local_row = 0; local_row < batch_rows; ++local_row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto sparse_row = fetch_row(local_row);
    const data_size_t global_row = static_cast<data_size_t>(start_row + local_row);
    target->PushOneRow(thread_id, global_row, sparse_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // this batch ended on the final row of the dataset: finalise it
  const int64_t dataset_rows = static_cast<int64_t>(target->num_data());
  if (start_row + batch_rows == dataset_rows) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
692
int LGBM_DatasetCreateFromMat(const void* data,
693
694
695
696
697
698
699
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
721
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
722
723
  auto param = Config::Str2Map(parameters);
  Config config;
724
725
726
727
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
728
  std::unique_ptr<Dataset> ret;
729
730
731
732
733
734
735
736
737
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
738

Guolin Ke's avatar
Guolin Ke committed
739
740
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
741
    Random rand(config.data_random_seed);
742
743
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
744
    sample_cnt = static_cast<int>(sample_indices.size());
745
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
746
    std::vector<std::vector<int>> sample_idx(ncol);
747
748
749

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
750
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
751
      auto idx = sample_indices[i];
752
753
754
755
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
756

757
758
759
760
761
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
762
        }
Guolin Ke's avatar
Guolin Ke committed
763
764
      }
    }
Guolin Ke's avatar
Guolin Ke committed
765
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
766
767
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
768
                                            ncol,
769
                                            Common::VectorSize<double>(sample_values).data(),
770
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
771
  } else {
772
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
773
    ret->CreateValid(
774
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
775
  }
776
777
778
779
780
781
782
783
784
785
786
787
788
789
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
790
791
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
792
  *out = ret.release();
793
  API_END();
794
795
}

Guolin Ke's avatar
Guolin Ke committed
796
int LGBM_DatasetCreateFromCSR(const void* indptr,
797
798
799
800
801
802
803
804
805
806
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
807
  API_BEGIN();
808
809
810
811
812
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
813
814
  auto param = Config::Str2Map(parameters);
  Config config;
815
816
817
818
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
819
  std::unique_ptr<Dataset> ret;
820
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
821
822
823
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
824
825
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
826
    auto sample_indices = rand.Sample(nrow, sample_cnt);
827
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
828
829
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
830
831
832
833
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
834
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
835
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
836
837
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
838
839
840
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
841
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
842
843
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
844
                                            static_cast<int>(num_col),
845
846
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
847
  } else {
848
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
849
    ret->CreateValid(
850
      reinterpret_cast<const Dataset*>(reference));
851
  }
852
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
853
  #pragma omp parallel for schedule(static)
854
  for (int i = 0; i < nindptr - 1; ++i) {
855
    OMP_LOOP_EX_BEGIN();
856
857
858
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
859
    OMP_LOOP_EX_END();
860
  }
861
  OMP_THROW_EX();
862
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
863
  *out = ret.release();
864
  API_END();
865
866
}

867
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
868
869
870
871
872
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
873
  API_BEGIN();
874
875
876
877
878
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
910
911
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
912
                                            static_cast<int>(num_col),
913
914
915
916
917
918
919
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
920

921
922
923
924
925
926
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
927
928
929
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
930
931
932
933
934
935
936
937
938
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
939
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
940
941
942
943
944
945
946
947
948
949
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
950
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
951
952
  auto param = Config::Str2Map(parameters);
  Config config;
953
954
955
956
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
957
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
958
959
960
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
961
962
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
963
    auto sample_indices = rand.Sample(nrow, sample_cnt);
964
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
965
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
966
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
967
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
968
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
969
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
970
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
971
972
973
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
974
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
975
976
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
977
978
        }
      }
979
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
980
    }
981
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
982
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
983
984
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
985
986
987
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
988
  } else {
989
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
990
    ret->CreateValid(
991
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
992
  }
993
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
994
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
995
  for (int i = 0; i < ncol_ptr - 1; ++i) {
996
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
997
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
998
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
999
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1000
1001
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
1002
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
Guolin Ke's avatar
Guolin Ke committed
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    auto bin_mapper = ret->FeatureBinMapper(feature_idx);
    if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
      int row_idx = 0;
      while (row_idx < nrow) {
        auto pair = col_it.NextNonZero();
        row_idx = pair.first;
        // no more data
        if (row_idx < 0) { break; }
        ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
      }
    } else {
      for (int row_idx = 0; row_idx < nrow; ++row_idx) {
        auto val = col_it.Get(row_idx);
        ret->PushOneData(tid, row_idx, group, sub_feature, val);
      }
Guolin Ke's avatar
Guolin Ke committed
1018
    }
1019
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1020
  }
1021
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1022
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1023
  *out = ret.release();
1024
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1025
1026
}

Guolin Ke's avatar
Guolin Ke committed
1027
int LGBM_DatasetGetSubset(
1028
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
1029
1030
1031
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
1032
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
1033
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1034
1035
  auto param = Config::Str2Map(parameters);
  Config config;
1036
1037
1038
1039
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
1040
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1041
  CHECK(num_used_row_indices > 0);
1042
1043
1044
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
1045
1046
1047
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
1048
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
1049
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
1050
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
1051
1052
1053
1054
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1055
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
1056
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
1057
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
1058
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
1059
1060
1061
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
1062
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1063
1064
1065
1066
1067
1068
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1069
int LGBM_DatasetGetFeatureNames(
1070
1071
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
1072
  int* num_feature_names) {
1073
1074
1075
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
1076
1077
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1078
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
1079
1080
1081
1082
  }
  API_END();
}

1083
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1084
/*!
 * \brief Destroy the Dataset object behind `handle`.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  auto* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1090
int LGBM_DatasetSaveBinary(DatasetHandle handle,
1091
                           const char* filename) {
1092
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1093
1094
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
1095
  API_END();
1096
1097
}

1098
1099
1100
1101
1102
1103
1104
1105
/*!
 * \brief Write a dataset to `filename` through Dataset::DumpTextFile.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  auto* source = reinterpret_cast<Dataset*>(handle);
  source->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1106
/*!
 * \brief Store an array of values under `field_name` in a dataset.
 *
 * `type` selects how `field_data` is interpreted: C_API_DTYPE_FLOAT32,
 * C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64. Any other type tag, or a setter that
 * reports failure, results in a std::runtime_error.
 *
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto* target = reinterpret_cast<Dataset*>(handle);
  const int32_t length = static_cast<int32_t>(num_element);
  // stays false when `type` matches none of the supported tags
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* values = reinterpret_cast<const float*>(field_data);
    stored = target->SetFloatField(field_name, values, length);
  } else if (type == C_API_DTYPE_INT32) {
    const int* values = reinterpret_cast<const int*>(field_data);
    stored = target->SetIntField(field_name, values, length);
  } else if (type == C_API_DTYPE_FLOAT64) {
    const double* values = reinterpret_cast<const double*>(field_data);
    stored = target->SetDoubleField(field_name, values, length);
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1125
/*!
 * \brief Look up `field_name` in a dataset and expose its storage to the caller.
 *
 * The float getter is tried first, then the int getter, then the double getter; the
 * first one that succeeds determines `*out_type`. If none succeeds a
 * std::runtime_error is thrown. A null `*out_ptr` after a successful lookup forces
 * `*out_len` to 0.
 *
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    // no getter recognised the field name
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with length zero
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1148
int LGBM_DatasetUpdateParamChecking(const char* old_parameters, const char* new_parameters) {
1149
  API_BEGIN();
1150
1151
1152
1153
1154
  auto old_param = Config::Str2Map(old_parameters);
  Config old_config;
  old_config.Set(old_param);
  auto new_param = Config::Str2Map(new_parameters);
  Booster::CheckDatasetResetConfig(old_config, new_param);
1155
1156
1157
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1158
int LGBM_DatasetGetNumData(DatasetHandle handle,
1159
                           int* out) {
1160
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1161
1162
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1163
  API_END();
1164
1165
}

Guolin Ke's avatar
Guolin Ke committed
1166
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1167
                              int* out) {
1168
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1169
1170
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1171
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1172
}
1173

1174
1175
1176
1177
1178
/*!
 * \brief Call Dataset::AddFeaturesFrom on `target`, passing `source`.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  auto* donor = reinterpret_cast<Dataset*>(source);
  auto* receiver = reinterpret_cast<Dataset*>(target);
  receiver->AddFeaturesFrom(donor);
  API_END();
}

1183
1184
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1185
int LGBM_BoosterCreate(const DatasetHandle train_data,
1186
1187
                       const char* parameters,
                       BoosterHandle* out) {
1188
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1189
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1190
1191
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1192
  API_END();
1193
1194
}

Guolin Ke's avatar
Guolin Ke committed
1195
int LGBM_BoosterCreateFromModelfile(
1196
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1197
  int* out_num_iterations,
1198
  BoosterHandle* out) {
1199
  API_BEGIN();
wxchan's avatar
wxchan committed
1200
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1201
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1202
  *out = ret.release();
1203
  API_END();
1204
1205
}

Guolin Ke's avatar
Guolin Ke committed
1206
int LGBM_BoosterLoadModelFromString(
1207
1208
1209
1210
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1211
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1212
1213
1214
1215
1216
1217
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1218
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1219
/*!
 * \brief Destroy the Booster object behind `handle`.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  auto* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1225
/*!
 * \brief Call Booster::ShuffleModels(start_iter, end_iter) on the booster behind `handle`.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1232
/*!
 * \brief Call Booster::MergeFrom on the booster behind `handle`, passing the booster
 *        behind `other_handle`.
 * \return 0 when the body completes; -1 when an exception is caught
 *         (its message is recorded through LGBM_SetLastError)
 */
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  auto* donor = reinterpret_cast<Booster*>(other_handle);
  auto* receiver = reinterpret_cast<Booster*>(handle);
  receiver->MergeFrom(donor);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1241
int LGBM_BoosterAddValidData(BoosterHandle handle,
1242
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1243
1244
1245
1246
1247
1248
1249
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1250
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1251
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1252
1253
1254
1255
1256
1257
1258
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1259
// Hands the raw `parameters` string to Booster::ResetConfig on the booster
// behind `handle`.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1266
// Writes the underlying boosting object's NumberOfClasses() to *out_len.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int class_count = booster->GetBoosting()->NumberOfClasses();
  *out_len = class_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1273
1274
1275
1276
1277
1278
1279
// Forwards `leaf_preds` together with its dimensions (`nrow`, `ncol`) to
// Booster::Refit on the booster behind `handle`.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1280
// Runs one training iteration via Booster::TrainOneIter().
// *is_finished is set to 1 when TrainOneIter() returns true, otherwise to 0.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1291
// Runs one training iteration using caller-supplied gradients and hessians.
// When SCORE_T_USE_DOUBLE is defined the call is rejected through Log::Fatal,
// since `grad`/`hess` are float arrays. Otherwise *is_finished is set to 1
// when Booster::TrainOneIter(grad, hess) returns true and to 0 when it
// returns false.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1309
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1316
// Writes the underlying boosting object's GetCurrentIteration() to
// *out_iteration.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int current = booster->GetBoosting()->GetCurrentIteration();
  *out_iteration = current;
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1322

1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
// Writes the underlying boosting object's NumModelPerIteration() to
// *out_tree_per_iteration.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int models_per_iter = booster->GetBoosting()->NumModelPerIteration();
  *out_tree_per_iteration = models_per_iter;
  API_END();
}

// Writes the underlying boosting object's NumberOfTotalModel() to *out_models.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int total_models = booster->GetBoosting()->NumberOfTotalModel();
  *out_models = total_models;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1337
// Writes the value of Booster::GetEvalCounts() to *out_len.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1344
// Passes `out_strs` to Booster::GetEvalNames and stores that call's return
// value in *out_len.
// NOTE(review): no buffer sizes are passed along, so `out_strs` presumably
// has to be pre-allocated by the caller - confirm against
// Booster::GetEvalNames.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1351
// Passes `out_strs` to Booster::GetFeatureNames and stores that call's return
// value in *out_len.
// NOTE(review): no buffer sizes are passed along, so `out_strs` presumably
// has to be pre-allocated by the caller - confirm against
// Booster::GetFeatureNames.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1358
// Writes MaxFeatureIdx() + 1 of the underlying boosting object to *out_len
// (the maximum feature index is zero-based, hence the +1 to get a count).
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int feature_count = booster->GetBoosting()->MaxFeatureIdx() + 1;
  *out_len = feature_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1365
// Fetches the evaluation results for dataset slot `data_idx` via
// GetEvalAt(data_idx) on the underlying boosting object.
// *out_len receives the number of results; each result is converted to double
// and written to out_results[0 .. *out_len). The caller must supply a buffer
// large enough, since no capacity is passed in.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto scores = booster->GetBoosting()->GetEvalAt(data_idx);
  const size_t score_count = scores.size();
  *out_len = static_cast<int>(score_count);
  for (size_t k = 0; k != score_count; ++k) {
    out_results[k] = static_cast<double>(scores[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1380
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1381
1382
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1383
1384
1385
1386
1387
1388
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1389
int LGBM_BoosterGetPredict(BoosterHandle handle,
1390
1391
1392
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1393
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1394
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1395
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1396
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1397
1398
}

Guolin Ke's avatar
Guolin Ke committed
1399
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1400
1401
1402
1403
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1404
                               const char* parameter,
1405
                               const char* result_filename) {
1406
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1407
1408
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1409
1410
1411
1412
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1413
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1414
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1415
                       config, result_filename);
1416
  API_END();
1417
1418
}

Guolin Ke's avatar
Guolin Ke committed
1419
// Computes how many output values a prediction over `num_row` rows produces:
// num_row (widened to int64_t) times NumPredictOneRow(...) of the underlying
// boosting object. The two boolean arguments of NumPredictOneRow are derived
// from `predict_type` (leaf-index prediction / contribution prediction).
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool wants_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool wants_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto values_per_row =
    booster->GetBoosting()->NumPredictOneRow(num_iteration, wants_leaf_index, wants_contrib);
  *out_len = static_cast<int64_t>(num_row) * values_per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1431
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1432
1433
1434
1435
1436
1437
1438
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1439
                              int64_t num_col,
1440
1441
                              int predict_type,
                              int num_iteration,
1442
                              const char* parameter,
1443
1444
                              int64_t* out_len,
                              double* out_result) {
1445
  API_BEGIN();
1446
1447
1448
1449
1450
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1451
1452
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1453
1454
1455
1456
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1457
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1458
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1459
  int nrow = static_cast<int>(nindptr - 1);
1460
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1461
                       config, out_result, out_len);
1462
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1463
}
1464

1465
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1466
1467
1468
1469
1470
1471
1472
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1473
                                       int64_t num_col,
1474
1475
1476
1477
1478
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1479
  API_BEGIN();
1480
1481
1482
1483
1484
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1485
1486
1487
1488
1489
1490
1491
1492
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1493
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1494
1495
1496
1497
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1498
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1509
                              const char* parameter,
1510
1511
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1512
1513
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1514
1515
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1526
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1527
1528
1529
1530
1531
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1532
1533
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1534
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1535
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1536
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1537
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1538
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1539
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1540
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1541
1542
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1543
1544
    return one_row;
  };
1545
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1546
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1547
1548
1549
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1550
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1551
1552
1553
1554
1555
1556
1557
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1558
                              const char* parameter,
1559
1560
                              int64_t* out_len,
                              double* out_result) {
1561
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1562
1563
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1564
1565
1566
1567
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1568
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1569
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1570
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1571
                       config, out_result, out_len);
1572
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1573
}
1574

1575
// Single-row variant of dense prediction: `data` is treated as a matrix with
// exactly one row of `ncol` values and passed to Booster::PredictSingleRow.
// `parameter` is parsed into a Config, and a positive num_threads in it sets
// the OpenMP thread count.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  auto param_map = Config::Str2Map(parameter);
  Config config;
  config.Set(param_map);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  // Row count is fixed to 1 for the single-row entry point.
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_reader, config, out_result, out_len);
  API_END();
}


1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
// Predicts for dense data supplied as an array of row pointers: data[i] is
// used as the start of row i, which holds `ncol` values of `data_type`.
// `parameter` is parsed into a Config, and a positive num_threads in it sets
// the OpenMP thread count.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  auto param_map = Config::Str2Map(parameter);
  Config config;
  config.Set(param_map);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_reader, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1622
int LGBM_BoosterSaveModel(BoosterHandle handle,
1623
                          int start_iteration,
1624
1625
                          int num_iteration,
                          const char* filename) {
1626
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1627
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1628
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1629
1630
1631
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1632
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1633
                                  int start_iteration,
1634
                                  int num_iteration,
1635
                                  int64_t buffer_len,
1636
                                  int64_t* out_len,
1637
                                  char* out_str) {
1638
1639
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1640
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1641
  *out_len = static_cast<int64_t>(model.size()) + 1;
1642
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1643
    std::memcpy(out_str, model.c_str(), *out_len);
1644
1645
1646
1647
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1648
int LGBM_BoosterDumpModel(BoosterHandle handle,
1649
                          int start_iteration,
1650
                          int num_iteration,
1651
1652
                          int64_t buffer_len,
                          int64_t* out_len,
1653
                          char* out_str) {
wxchan's avatar
wxchan committed
1654
1655
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1656
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1657
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1658
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1659
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1660
  }
1661
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1662
}
1663

Guolin Ke's avatar
Guolin Ke committed
1664
// Reads Booster::GetLeafValue(tree_idx, leaf_idx), converts it to double and
// stores it in *out_val.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1674
// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val) on the booster
// behind `handle`.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
// Copies the vector returned by Booster::FeatureImportance(num_iteration,
// importance_type) element by element into `out_results`. No capacity is
// passed in, so the caller must supply a buffer with at least as many slots
// as that vector has entries.
// Returns 0 on success, -1 if API_END catches an exception.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* cursor = out_results;
  for (const double value : importances) {
    *cursor = value;
    ++cursor;
  }
  API_END();
}

1697
1698
1699
1700
1701
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1702
  Config config;
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Calls Network::Dispose().
// Returns 0 on success; API_END converts any caught exception into -1 and
// records its message via LGBM_SetLastError.
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1719
1720
1721
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1722
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1723
  if (num_machines > 1) {
1724
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1725
1726
1727
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1728

Guolin Ke's avatar
Guolin Ke committed
1729
// ---- start of some help functions
1730
1731
1732

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1733
  if (data_type == C_API_DTYPE_FLOAT32) {
1734
1735
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1736
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1737
        std::vector<double> ret(num_col);
1738
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1739
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1740
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1741
1742
1743
1744
        }
        return ret;
      };
    } else {
1745
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1746
        std::vector<double> ret(num_col);
1747
        for (int i = 0; i < num_col; ++i) {
1748
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1749
1750
1751
1752
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1753
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1754
1755
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1756
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1757
        std::vector<double> ret(num_col);
1758
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1759
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1760
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1761
1762
1763
1764
        }
        return ret;
      };
    } else {
1765
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1766
        std::vector<double> ret(num_col);
1767
        for (int i = 0; i < num_col; ++i) {
1768
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1769
1770
1771
1772
1773
        }
        return ret;
      };
    }
  }
1774
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1775
1776
1777
1778
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1779
1780
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1781
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1782
1783
1784
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1785
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1786
          ret.emplace_back(i, raw_values[i]);
1787
        }
Guolin Ke's avatar
Guolin Ke committed
1788
1789
1790
      }
      return ret;
    };
1791
  }
Guolin Ke's avatar
Guolin Ke committed
1792
  return nullptr;
1793
1794
}

1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
// `data` is an array of pointers, one per row; data[row_idx] is the start of
// that row's `num_col` values of type `data_type`.
// Returns a function that yields row `row_idx` as (column index, value)
// pairs, keeping a value when its magnitude exceeds kZeroThreshold or when it
// is NaN. The dense reader for a row is created inside the returned function,
// so an unsupported `data_type` only surfaces when a row is requested.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    // Treat the single row as a 1 x num_col row-major matrix.
    auto single_row_reader = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto dense_row = single_row_reader(0);
    std::vector<std::pair<int, double>> ret;
    const int width = static_cast<int>(dense_row.size());
    for (int col = 0; col < width; ++col) {
      const bool keep = std::fabs(dense_row[col]) > kZeroThreshold || std::isnan(dense_row[col]);
      if (keep) {
        ret.emplace_back(col, dense_row[col]);
      }
    }
    return ret;
  };
}

1811
std::function<std::vector<std::pair<int, double>>(int idx)>
1812
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1813
  if (data_type == C_API_DTYPE_FLOAT32) {
1814
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1815
    if (indptr_type == C_API_DTYPE_INT32) {
1816
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1817
      return [=] (int idx) {
1818
1819
1820
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1821
1822
1823
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1824
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1825
          ret.emplace_back(indices[i], data_ptr[i]);
1826
1827
1828
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1829
    } else if (indptr_type == C_API_DTYPE_INT64) {
1830
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1831
      return [=] (int idx) {
1832
1833
1834
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1835
1836
1837
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1838
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1839
          ret.emplace_back(indices[i], data_ptr[i]);
1840
1841
1842
1843
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1844
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1845
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1846
    if (indptr_type == C_API_DTYPE_INT32) {
1847
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1848
      return [=] (int idx) {
1849
1850
1851
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1852
1853
1854
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1855
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1856
          ret.emplace_back(indices[i], data_ptr[i]);
1857
1858
1859
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1860
    } else if (indptr_type == C_API_DTYPE_INT64) {
1861
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1862
      return [=] (int idx) {
1863
1864
1865
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1866
1867
1868
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1869
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1870
          ret.emplace_back(indices[i], data_ptr[i]);
1871
1872
1873
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1874
1875
    }
  }
1876
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1877
1878
}

Guolin Ke's avatar
Guolin Ke committed
1879
std::function<std::pair<int, double>(int idx)>
1880
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1881
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1882
  if (data_type == C_API_DTYPE_FLOAT32) {
1883
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1884
    if (col_ptr_type == C_API_DTYPE_INT32) {
1885
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1886
1887
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1888
1889
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1890
1891
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1892
        }
Guolin Ke's avatar
Guolin Ke committed
1893
1894
1895
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1896
      };
Guolin Ke's avatar
Guolin Ke committed
1897
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1898
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1899
1900
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1901
1902
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1903
1904
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1905
        }
Guolin Ke's avatar
Guolin Ke committed
1906
1907
1908
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1909
      };
Guolin Ke's avatar
Guolin Ke committed
1910
    }
Guolin Ke's avatar
Guolin Ke committed
1911
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1912
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1913
    if (col_ptr_type == C_API_DTYPE_INT32) {
1914
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1915
1916
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1917
1918
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1919
1920
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1921
        }
Guolin Ke's avatar
Guolin Ke committed
1922
1923
1924
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1925
      };
Guolin Ke's avatar
Guolin Ke committed
1926
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1927
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1928
1929
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1930
1931
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1932
1933
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1934
        }
Guolin Ke's avatar
Guolin Ke committed
1935
1936
1937
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1938
      };
Guolin Ke's avatar
Guolin Ke committed
1939
1940
    }
  }
1941
  throw std::runtime_error("Unknown data type in CSC matrix");
1942
1943
}

Guolin Ke's avatar
Guolin Ke committed
1944
// Creates an iterator over column `col_idx` of a CSC matrix. All arguments
// are forwarded unchanged to IterateFunctionFromCSC, whose result becomes the
// iterator's entry-fetching function.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when no entry is stored for
// that row.
// The iterator only moves forward: stored entries are consumed until the
// cursor's row index reaches or passes `idx`, or the column is exhausted.
// A request for a row that lies before the cursor therefore returns 0
// unless it equals the cursor's current row.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // A negative row index marks the end of the column.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0f;
  }
  return cur_val_;
}

// Returns the next stored (row index, value) entry of the column and advances
// the position counter. Once the fetch function reports the end (a negative
// row index), that sentinel is returned and every later call yields
// (-1, 0.0) without fetching again.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}