c_api.cpp 74.1 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
16
#include <LightGBM/utils/locale_context.h>
17
18
19
20
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
21
22

#include <string>
23
24
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
25
#include <memory>
wxchan's avatar
wxchan committed
26
#include <mutex>
27
28
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
29

30
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
31

Guolin Ke's avatar
Guolin Ke committed
32
33
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/*!
 * \brief Store the message of a caught std::exception as the last API error.
 * \param ex Exception whose what() text is passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
/*!
 * \brief Store a string error message as the last API error.
 * \param ex Message text passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN() opens a try block and API_END() closes it.
// API_END() passes any caught std::exception or std::string (or the text
// "unknown exception" for anything else) to LGBM_APIHandleException and
// returns its result; when nothing is thrown the enclosing function returns 0.
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

50
51
52
53
54
55
56
57
// Number of slots in Booster's array of cached single-row predictors,
// which is indexed by predict_type.
const int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
58
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
75
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
76
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
77
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
78
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
79
    num_total_model_ = boosting->NumberOfTotalModel();
80
  }
81

82
  ~SingleRowPredictor() {}
83

Guolin Ke's avatar
Guolin Ke committed
84
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
85
86
87
88
89
    return early_stop_ == config.pred_early_stop &&
      early_stop_freq_ == config.pred_early_stop_freq &&
      early_stop_margin_ == config.pred_early_stop_margin &&
      iter_ == iter &&
      num_total_model_ == boosting->NumberOfTotalModel();
90
  }
Guolin Ke's avatar
Guolin Ke committed
91

92
93
94
95
96
97
98
99
100
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
101
class Booster {
Nikita Titov's avatar
Nikita Titov committed
102
 public:
Guolin Ke's avatar
Guolin Ke committed
103
  explicit Booster(const char* filename) {
104
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
105
106
  }

Guolin Ke's avatar
Guolin Ke committed
107
  Booster(const Dataset* train_data,
108
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
109
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
110
    config_.Set(param);
111
112
113
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
114
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
115
    if (config_.input_model.size() > 0) {
116
117
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
118
    }
Guolin Ke's avatar
Guolin Ke committed
119

Guolin Ke's avatar
Guolin Ke committed
120
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
121

122
123
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
124
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
125
    if (config_.tree_learner == std::string("feature")) {
126
      Log::Fatal("Do not support feature parallel in c api");
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
129
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
130
      config_.tree_learner = "serial";
131
    }
Guolin Ke's avatar
Guolin Ke committed
132
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
133
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
134
135
136
137
138
  }

  /*!
   * \brief Forward to Boosting::MergeFrom with the other Booster's model, holding this Booster's mutex.
   * \param other Booster whose boosting model is passed on
   */
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* other_model = other->boosting_.get();
    boosting_->MergeFrom(other_model);
  }

  /*! \brief Nothing to do explicitly; owned members clean up after themselves */
  ~Booster() {}
143

144
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
145
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
146
147
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152
153
154
155
156
157
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
158
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
159
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
160
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
166
167
168
169
170
171
172
173
174
175
176
  }

  /*!
   * \brief Switch to another training dataset and rebuild objective and metrics for it.
   * \param train_data New training dataset; nothing happens when it is the pointer already in use
   */
  void ResetTrainingData(const Dataset* train_data) {
    if (train_data == train_data_) {
      return;
    }
    std::lock_guard<std::mutex> guard(mutex_);
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_, objective_fun_.get(),
                                 Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
  static void CheckDatasetResetConfig(
      const Config& old_config,
      const std::unordered_map<std::string, std::string>& new_param) {
    Config new_config;
    new_config.Set(new_param);
    if (new_param.count("data_random_seed") &&
        new_config.data_random_seed != old_config.data_random_seed) {
      Log::Fatal("Cannot change data_random_seed after constructed Dataset handle.");
    }
    if (new_param.count("max_bin") &&
        new_config.max_bin != old_config.max_bin) {
      Log::Fatal("Cannot change max_bin after constructed Dataset handle.");
    }
    if (new_param.count("max_bin_by_feature") &&
        new_config.max_bin_by_feature != old_config.max_bin_by_feature) {
      Log::Fatal(
          "Cannot change max_bin_by_feature after constructed Dataset handle.");
    }
    if (new_param.count("bin_construct_sample_cnt") &&
        new_config.bin_construct_sample_cnt !=
            old_config.bin_construct_sample_cnt) {
      Log::Fatal(
          "Cannot change bin_construct_sample_cnt after constructed Dataset "
          "handle.");
    }
    if (new_param.count("min_data_in_bin") &&
        new_config.min_data_in_bin != old_config.min_data_in_bin) {
      Log::Fatal(
          "Cannot change min_data_in_bin after constructed Dataset handle.");
    }
    if (new_param.count("use_missing") &&
        new_config.use_missing != old_config.use_missing) {
      Log::Fatal("Cannot change use_missing after constructed Dataset handle.");
    }
    if (new_param.count("zero_as_missing") &&
        new_config.zero_as_missing != old_config.zero_as_missing) {
      Log::Fatal(
          "Cannot change zero_as_missing after constructed Dataset handle.");
    }
    if (new_param.count("categorical_feature") &&
        new_config.categorical_feature != old_config.categorical_feature) {
      Log::Fatal(
          "Cannot change categorical_feature after constructed Dataset "
          "handle.");
    }
    if (new_param.count("feature_pre_filter") &&
        new_config.feature_pre_filter != old_config.feature_pre_filter) {
      Log::Fatal(
          "Cannot change feature_pre_filter after constructed Dataset handle.");
    }
    if (new_param.count("is_enable_sparse") &&
        new_config.is_enable_sparse != old_config.is_enable_sparse) {
      Log::Fatal(
          "Cannot change is_enable_sparse after constructed Dataset handle.");
    }
    if (new_param.count("pre_partition") &&
        new_config.pre_partition != old_config.pre_partition) {
      Log::Fatal(
          "Cannot change pre_partition after constructed Dataset handle.");
    }
    if (new_param.count("enable_bundle") &&
        new_config.enable_bundle != old_config.enable_bundle) {
      Log::Fatal(
          "Cannot change enable_bundle after constructed Dataset handle.");
    }
    if (new_param.count("header") && new_config.header != old_config.header) {
      Log::Fatal("Cannot change header after constructed Dataset handle.");
    }
    if (new_param.count("two_round") &&
        new_config.two_round != old_config.two_round) {
      Log::Fatal("Cannot change two_round after constructed Dataset handle.");
    }
    if (new_param.count("label_column") &&
        new_config.label_column != old_config.label_column) {
      Log::Fatal(
          "Cannot change label_column after constructed Dataset handle.");
    }
    if (new_param.count("weight_column") &&
        new_config.weight_column != old_config.weight_column) {
      Log::Fatal(
          "Cannot change weight_column after constructed Dataset handle.");
    }
    if (new_param.count("group_column") &&
        new_config.group_column != old_config.group_column) {
      Log::Fatal(
          "Cannot change group_column after constructed Dataset handle.");
    }
    if (new_param.count("ignore_column") &&
        new_config.ignore_column != old_config.ignore_column) {
      Log::Fatal(
          "Cannot change ignore_column after constructed Dataset handle.");
    }
    if (new_param.count("forcedbins_filename")) {
      Log::Fatal("Cannot change forced bins after constructed Dataset handle.");
    }
    if (new_param.count("min_data_in_leaf") &&
        new_config.min_data_in_leaf < old_config.min_data_in_leaf &&
        old_config.feature_pre_filter) {
      Log::Fatal(
          "Reducing `min_data_in_leaf` with `feature_pre_filter=true` may "
          "cause unexpected behaviour "
          "for features that were pre-filtered by the larger "
          "`min_data_in_leaf`.\n"
          "You need to set `feature_pre_filter=false` to dynamically change "
          "the `min_data_in_leaf`.");
    }
  }

wxchan's avatar
wxchan committed
287
  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
288
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
289
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
290
    if (param.count("num_class")) {
291
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
292
    }
Guolin Ke's avatar
Guolin Ke committed
293
294
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
295
    }
Guolin Ke's avatar
Guolin Ke committed
296
    if (param.count("metric")) {
297
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
298
    }
299
300
    CheckDatasetResetConfig(config_, param);

Guolin Ke's avatar
Guolin Ke committed
301
    config_.Set(param);
302

303
304
305
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
306
307
308

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
309
310
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
311
312
313
314
315
316
317
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
318
319
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
320
    }
Guolin Ke's avatar
Guolin Ke committed
321

Guolin Ke's avatar
Guolin Ke committed
322
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
323
324
325
326
327
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
328
329
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
330
331
332
333
334
335
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
336
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
337
  }
Guolin Ke's avatar
Guolin Ke committed
338

339
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
340
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
341
    return boosting_->TrainOneIter(nullptr, nullptr);
342
343
  }

Guolin Ke's avatar
Guolin Ke committed
344
345
346
347
348
349
350
351
352
353
354
  /*!
   * \brief Copy a flat array of leaf predictions into nested vectors and pass them to Boosting::RefitTree.
   * \param leaf_preds Flat array read as nrow rows of ncol values each
   * \param nrow Number of rows
   * \param ncol Number of values per row
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      // compute the flat offset in size_t: i * ncol can exceed the range of int
      const size_t row_offset = static_cast<size_t>(i) * static_cast<size_t>(ncol);
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[row_offset + static_cast<size_t>(j)];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

355
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
356
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
357
    return boosting_->TrainOneIter(gradients, hessians);
358
359
  }

wxchan's avatar
wxchan committed
360
361
362
363
364
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

365
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
366
367
368
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
369
370
371
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
372
    }
373
    std::lock_guard<std::mutex> lock(mutex_);
374
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
375
376
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
377
                                                                       config, num_iteration));
378
379
380
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
381
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
382

383
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
384
385
386
  }


387
  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
Guolin Ke's avatar
Guolin Ke committed
388
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
389
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
390
               double* out_result, int64_t* out_len) {
391
392
393
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
394
    }
wxchan's avatar
wxchan committed
395
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
396
397
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
398
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
399
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
400
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
401
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
402
      is_raw_score = true;
403
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
404
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
405
406
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
407
    }
Guolin Ke's avatar
Guolin Ke committed
408

Guolin Ke's avatar
Guolin Ke committed
409
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
410
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
411
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
412
    auto pred_fun = predictor.GetPredictFunction();
413
414
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
415
    for (int i = 0; i < nrow; ++i) {
416
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
417
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
418
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
419
      pred_fun(one_row, pred_wrt_ptr);
420
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
421
    }
422
    OMP_THROW_EX();
423
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
424
425
426
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
427
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
428
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
429
430
431
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
432
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
437
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
438
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
439
440
441
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
442
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
443
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
444
    bool bool_data_has_header = data_has_header > 0 ? true : false;
445
    predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
Guolin Ke's avatar
Guolin Ke committed
446
447
  }

Guolin Ke's avatar
Guolin Ke committed
448
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
449
450
451
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

452
453
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
454
  }
455

456
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen
   */
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str, std::strlen(model_str));
  }

461
462
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
463
464
  }

465
  std::string DumpModel(int start_iteration, int num_iteration) {
466
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
467
  }
468

469
470
471
472
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

473
474
475
476
477
478
479
480
481
482
  /*! \brief Return Boosting::GetUpperBoundValue(), read while holding the mutex */
  double UpperBoundValue() const {
    std::lock_guard<std::mutex> guard(mutex_);
    const double upper = boosting_->GetUpperBoundValue();
    return upper;
  }

  /*! \brief Return Boosting::GetLowerBoundValue(), read while holding the mutex */
  double LowerBoundValue() const {
    std::lock_guard<std::mutex> guard(mutex_);
    const double lower = boosting_->GetLowerBoundValue();
    return lower;
  }

Guolin Ke's avatar
Guolin Ke committed
483
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
484
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
485
486
487
488
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
489
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
490
491
  }

492
  void ShuffleModels(int start_iter, int end_iter) {
493
    std::lock_guard<std::mutex> lock(mutex_);
494
    boosting_->ShuffleModels(start_iter, end_iter);
495
496
  }

wxchan's avatar
wxchan committed
497
498
499
500
501
502
503
  /*!
   * \brief Count the evaluation names of all training metrics.
   * \return Sum of GetName().size() over train_metric_
   */
  int GetEvalCounts() const {
    int total_names = 0;
    for (size_t i = 0; i < train_metric_.size(); ++i) {
      total_names += static_cast<int>(train_metric_[i]->GetName().size());
    }
    return total_names;
  }
504

505
506
  /*!
   * \brief Copy the names of all training metrics into caller-provided buffers.
   * \param out_strs Array of output buffers; only the first len entries are written
   * \param len Number of buffers in out_strs
   * \param buffer_len Size of each buffer in bytes; longer names are truncated and NUL-terminated
   * \param out_buffer_len Receives the largest name length plus one seen across all names
   * \return Total number of names, which may exceed len
   */
  int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
    *out_buffer_len = 0;
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        const size_t needed = name.size() + 1;
        // buffer_len == 0 would make buffer_len - 1 wrap around and write far
        // outside the buffer, so nothing is written in that case
        if (idx < len && buffer_len > 0) {
          std::memcpy(out_strs[idx], name.c_str(), std::min(needed, buffer_len));
          out_strs[idx][buffer_len - 1] = '\0';
        }
        *out_buffer_len = std::max(needed, *out_buffer_len);
        ++idx;
      }
    }
    return idx;
  }

521
522
  /*!
   * \brief Copy the feature names of the boosting model into caller-provided buffers.
   * \param out_strs Array of output buffers; only the first len entries are written
   * \param len Number of buffers in out_strs
   * \param buffer_len Size of each buffer in bytes; longer names are truncated and NUL-terminated
   * \param out_buffer_len Receives the largest name length plus one seen across all names
   * \return Total number of names, which may exceed len
   */
  int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
    *out_buffer_len = 0;
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      const size_t needed = name.size() + 1;
      // buffer_len == 0 would make buffer_len - 1 wrap around and write far
      // outside the buffer, so nothing is written in that case
      if (idx < len && buffer_len > 0) {
        std::memcpy(out_strs[idx], name.c_str(), std::min(needed, buffer_len));
        out_strs[idx][buffer_len - 1] = '\0';
      }
      *out_buffer_len = std::max(needed, *out_buffer_len);
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
535
  /*! \brief Read-only, non-owning pointer to the wrapped boosting model */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
536

Nikita Titov's avatar
Nikita Titov committed
537
 private:
wxchan's avatar
wxchan committed
538
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
539
  std::unique_ptr<Boosting> boosting_;
540
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
541

Guolin Ke's avatar
Guolin Ke committed
542
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
543
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
544
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
545
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
546
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
547
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
548
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
549
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
550
  /*! \brief mutex for threading safe call */
551
  mutable std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
552
553
};

554
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
555
556
557

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
558
559
560
561
562
563
564
565
// Helper functions used to convert input data into per-row accessors

// Forward declarations; the definitions are outside this part of the file.
// NOTE(review): the per-function notes below are inferred from the signatures only.

// Returns a callback mapping a row index to a std::vector<double>.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Returns a callback mapping a row index to a vector of (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same return shape as above, taking an array of row pointers instead of one matrix pointer.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

// Same return shape as above, taking CSR-style arrays (indptr, indices, data).
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
572
573
574

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  // Constructor is defined outside this part of the file; col_idx selects the column.
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Returns the value at idx; indices must be requested in ascending order
  double Get(int idx);
  // Returns the next non-zero (index, value) pair; an index < 0 means no more data
  std::pair<int, double> NextNonZero();

 private:
  // Iteration state
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  // Accessor set up by the constructor; returns an (int, double) pair for a position
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
594
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
595
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
596
597
}

Guolin Ke's avatar
Guolin Ke committed
598
int LGBM_DatasetCreateFromFile(const char* filename,
599
600
601
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
602
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
603
604
  auto param = Config::Str2Map(parameters);
  Config config;
605
606
607
608
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
609
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
610
  if (reference == nullptr) {
611
    if (Network::num_machines() == 1) {
612
      *out = loader.LoadFromFile(filename);
613
    } else {
614
      *out = loader.LoadFromFile(filename, Network::rank(), Network::num_machines());
615
    }
Guolin Ke's avatar
Guolin Ke committed
616
  } else {
617
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
618
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
619
  }
620
  API_END();
Guolin Ke's avatar
Guolin Ke committed
621
622
}

623

Guolin Ke's avatar
Guolin Ke committed
624
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
625
626
627
628
629
630
631
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
632
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
633
634
  auto param = Config::Str2Map(parameters);
  Config config;
635
636
637
638
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
639
  DatasetLoader loader(config, nullptr, 1, nullptr);
640
641
642
643
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
644
645
}

646

Guolin Ke's avatar
Guolin Ke committed
647
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
648
649
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
650
651
652
653
654
655
656
657
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
658
int LGBM_DatasetPushRows(DatasetHandle dataset,
659
660
661
662
663
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
664
665
666
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
667
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
668
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
669
  for (int i = 0; i < nrow; ++i) {
670
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
671
672
673
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
674
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
675
  }
676
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
677
678
679
680
681
682
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
683
/*!
 * \brief Pushes the rows of a CSR chunk into an existing Dataset, starting at row start_row.
 * \param dataset Handle of the dataset receiving the rows
 * \param indptr CSR row pointer array
 * \param indptr_type Type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Type code of data
 * \param nindptr Number of entries in indptr; the chunk holds nindptr - 1 rows
 * \param nelem Number of stored elements
 * \param start_row Index in the dataset of the first pushed row
 * \note The unnamed int64_t parameter is not used by this function.
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto entries = fetch_row(row);
    const data_size_t dest_row = static_cast<data_size_t>(start_row + row);
    target->PushOneRow(thread_id, dest_row, entries);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // the chunk that reaches the final row finishes loading
  const int64_t dataset_rows = static_cast<int64_t>(target->num_data());
  if (start_row + nrow == dataset_rows) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
714
int LGBM_DatasetCreateFromMat(const void* data,
715
716
717
718
719
720
721
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
743
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
744
745
  auto param = Config::Str2Map(parameters);
  Config config;
746
747
748
749
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
750
  std::unique_ptr<Dataset> ret;
751
752
753
754
755
756
757
758
759
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
760

Guolin Ke's avatar
Guolin Ke committed
761
762
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
763
    Random rand(config.data_random_seed);
764
765
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
766
    sample_cnt = static_cast<int>(sample_indices.size());
767
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
768
    std::vector<std::vector<int>> sample_idx(ncol);
769
770
771

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
772
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
773
      auto idx = sample_indices[i];
774
775
776
777
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
778

779
780
781
782
783
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
784
        }
Guolin Ke's avatar
Guolin Ke committed
785
786
      }
    }
Guolin Ke's avatar
Guolin Ke committed
787
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
788
789
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
790
                                            ncol,
791
                                            Common::VectorSize<double>(sample_values).data(),
792
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
793
  } else {
794
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
795
    ret->CreateValid(
796
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
797
  }
798
799
800
801
802
803
804
805
806
807
808
809
810
811
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
812
813
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
814
  *out = ret.release();
815
  API_END();
816
817
}

Guolin Ke's avatar
Guolin Ke committed
818
int LGBM_DatasetCreateFromCSR(const void* indptr,
819
820
821
822
823
824
825
826
827
828
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
829
  API_BEGIN();
830
831
832
833
834
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
835
836
  auto param = Config::Str2Map(parameters);
  Config config;
837
838
839
840
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
841
  std::unique_ptr<Dataset> ret;
842
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
843
844
845
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
846
847
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
848
    auto sample_indices = rand.Sample(nrow, sample_cnt);
849
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
850
851
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
852
853
854
855
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Nikita Titov's avatar
Nikita Titov committed
856
        CHECK_LT(inner_data.first, num_col);
Guolin Ke's avatar
Guolin Ke committed
857
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
858
859
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
860
861
862
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
863
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
864
865
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
866
                                            static_cast<int>(num_col),
867
868
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
869
  } else {
870
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
871
    ret->CreateValid(
872
      reinterpret_cast<const Dataset*>(reference));
873
  }
874
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
875
  #pragma omp parallel for schedule(static)
876
  for (int i = 0; i < nindptr - 1; ++i) {
877
    OMP_LOOP_EX_BEGIN();
878
879
880
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
881
    OMP_LOOP_EX_END();
882
  }
883
  OMP_THROW_EX();
884
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
885
  *out = ret.release();
886
  API_END();
887
888
}

889
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
890
891
892
893
894
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
895
  API_BEGIN();
896
897
898
899
900
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
Nikita Titov's avatar
Nikita Titov committed
924
        CHECK_LT(inner_data.first, num_col);
925
926
927
928
929
930
931
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
932
933
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
934
                                            static_cast<int>(num_col),
935
936
937
938
939
940
941
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
942

943
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
944
945
  std::vector<std::pair<int, double>> thread_buffer;
  #pragma omp parallel for schedule(static) private(thread_buffer)
946
947
948
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
949
      const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
950
951
      get_row_fun(i, thread_buffer);
      ret->PushOneRow(tid, i, thread_buffer);
952
953
954
955
956
957
958
959
960
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
961
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
962
963
964
965
966
967
968
969
970
971
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
972
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
973
974
  auto param = Config::Str2Map(parameters);
  Config config;
975
976
977
978
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
979
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
980
981
982
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
983
984
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
985
    auto sample_indices = rand.Sample(nrow, sample_cnt);
986
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
987
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
988
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
989
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
990
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
991
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
992
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
993
994
995
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
996
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
997
998
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
999
1000
        }
      }
1001
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1002
    }
1003
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1004
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
1005
1006
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
1007
1008
1009
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
1010
  } else {
1011
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
1012
    ret->CreateValid(
1013
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
1014
  }
1015
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1016
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
1017
  for (int i = 0; i < ncol_ptr - 1; ++i) {
1018
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1019
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1020
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
1021
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1022
1023
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
1024
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
Guolin Ke's avatar
Guolin Ke committed
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
    auto bin_mapper = ret->FeatureBinMapper(feature_idx);
    if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
      int row_idx = 0;
      while (row_idx < nrow) {
        auto pair = col_it.NextNonZero();
        row_idx = pair.first;
        // no more data
        if (row_idx < 0) { break; }
        ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
      }
    } else {
      for (int row_idx = 0; row_idx < nrow; ++row_idx) {
        auto val = col_it.Get(row_idx);
        ret->PushOneData(tid, row_idx, group, sub_feature, val);
      }
Guolin Ke's avatar
Guolin Ke committed
1040
    }
1041
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1042
  }
1043
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1044
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1045
  *out = ret.release();
1046
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1047
1048
}

Guolin Ke's avatar
Guolin Ke committed
1049
int LGBM_DatasetGetSubset(
1050
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
1051
1052
1053
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
1054
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
1055
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1056
1057
  auto param = Config::Str2Map(parameters);
  Config config;
1058
1059
1060
1061
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
1062
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
1063
  CHECK_GT(num_used_row_indices, 0);
1064
1065
1066
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
1067
1068
1069
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
1070
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
1071
  ret->CopyFeatureMapperFrom(full_dataset);
1072
  ret->CopySubrow(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
1073
1074
1075
1076
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1077
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
1078
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
1079
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
1080
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
1084
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1085
1086
1087
1088
1089
1090
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1091
int LGBM_DatasetGetFeatureNames(
1092
1093
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
1094
  int* num_feature_names) {
1095
1096
1097
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
1098
1099
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1100
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
1101
1102
1103
1104
  }
  API_END();
}

1105
1106
1107
#ifdef _MSC_VER
  #pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
1108
/*!
 * \brief Deletes the Dataset behind the handle.
 * \param handle Handle of the dataset to delete
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1114
int LGBM_DatasetSaveBinary(DatasetHandle handle,
1115
                           const char* filename) {
1116
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1117
1118
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
1119
  API_END();
1120
1121
}

1122
1123
1124
1125
1126
1127
1128
1129
/*!
 * \brief Forwards to Dataset::DumpTextFile for the dataset behind the handle.
 * \param handle Handle of the dataset
 * \param filename File name passed to DumpTextFile
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1130
/*!
 * \brief Sets a named field of the dataset from a typed buffer.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param field_data Buffer holding num_element values of the type given by type
 * \param num_element Number of elements in field_data
 * \param type One of C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64;
 *        any other value is reported as an error
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  // stays false when the type code is unknown or the setter rejects the field
  bool accepted = false;
  if (type == C_API_DTYPE_FLOAT32) {
    accepted = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    accepted = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    accepted = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  }
  if (!accepted) {
    Log::Fatal("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1149
/*!
 * \brief Looks up a named field of the dataset, trying the float, int and double
 *        getters in that order.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param out_len Receives the number of elements; forced to 0 when *out_ptr is nullptr
 * \param out_ptr Receives the pointer to the field data
 * \param out_type Receives the C_API_DTYPE_* code of the getter that succeeded
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  // the first getter that reports success decides the type
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    found = false;
  }
  if (!found) {
    Log::Fatal("Field not found");
  }
  // a field that was found but has no data is reported with length 0
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1172
int LGBM_DatasetUpdateParamChecking(const char* old_parameters, const char* new_parameters) {
1173
  API_BEGIN();
1174
1175
1176
1177
1178
  auto old_param = Config::Str2Map(old_parameters);
  Config old_config;
  old_config.Set(old_param);
  auto new_param = Config::Str2Map(new_parameters);
  Booster::CheckDatasetResetConfig(old_config, new_param);
1179
1180
1181
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1182
int LGBM_DatasetGetNumData(DatasetHandle handle,
1183
                           int* out) {
1184
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1185
1186
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1187
  API_END();
1188
1189
}

Guolin Ke's avatar
Guolin Ke committed
1190
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1191
                              int* out) {
1192
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1193
1194
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1195
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1196
}
1197

1198
1199
1200
1201
1202
/*!
 * \brief Calls Dataset::AddFeaturesFrom on the target dataset with the source dataset.
 * \param target Handle of the dataset that receives the features
 * \param source Handle of the dataset passed to AddFeaturesFrom
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  Dataset* receiver = reinterpret_cast<Dataset*>(target);
  receiver->AddFeaturesFrom(donor);
  API_END();
}

1207
1208
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1209
int LGBM_BoosterCreate(const DatasetHandle train_data,
1210
1211
                       const char* parameters,
                       BoosterHandle* out) {
1212
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1213
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1214
1215
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1216
  API_END();
1217
1218
}

Guolin Ke's avatar
Guolin Ke committed
1219
int LGBM_BoosterCreateFromModelfile(
1220
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1221
  int* out_num_iterations,
1222
  BoosterHandle* out) {
1223
  API_BEGIN();
1224
  LocaleContext withLocaleContext("C");
wxchan's avatar
wxchan committed
1225
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1226
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1227
  *out = ret.release();
1228
  API_END();
1229
1230
}

Guolin Ke's avatar
Guolin Ke committed
1231
int LGBM_BoosterLoadModelFromString(
1232
1233
1234
1235
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
1236
  LocaleContext withLocaleContext("C");
wxchan's avatar
wxchan committed
1237
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1238
1239
1240
1241
1242
1243
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1244
1245
1246
#ifdef _MSC_VER
  #pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
1247
/*!
 * \brief Deletes the Booster behind the handle.
 * \param handle Handle of the booster to delete
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1253
/*!
 * \brief Forwards to Booster::ShuffleModels with the given iteration bounds.
 * \param handle Handle of the booster
 * \param start_iter First bound passed to ShuffleModels
 * \param end_iter Second bound passed to ShuffleModels
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1260
// Calls MergeFrom on the Booster behind `handle`, passing the Booster behind
// `other_handle`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  auto* target = reinterpret_cast<Booster*>(handle);
  auto* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1269
int LGBM_BoosterAddValidData(BoosterHandle handle,
1270
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1271
1272
1273
1274
1275
1276
1277
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1278
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1279
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1280
1281
1282
1283
1284
1285
1286
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1287
// Forwards the `parameters` string to Booster::ResetConfig.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1294
// Writes the boosting object's NumberOfClasses() to `out_len`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto num_classes = booster->GetBoosting()->NumberOfClasses();
  *out_len = num_classes;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1301
1302
1303
1304
1305
1306
1307
// Forwards `leaf_preds` together with `nrow` and `ncol` to Booster::Refit.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1308
// Runs Booster::TrainOneIter() and stores 1 in `is_finished` when it
// returns a true value, 0 otherwise.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1319
// Runs Booster::TrainOneIter(grad, hess) with caller-supplied arrays and
// stores 1 in `is_finished` when it returns a true value, 0 otherwise.
// When SCORE_T_USE_DOUBLE is defined the call is replaced by Log::Fatal.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1337
// Calls Booster::RollbackOneIter() on the Booster behind `handle`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1344
// Writes the boosting object's GetCurrentIteration() to `out_iteration`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto current_iter = booster->GetBoosting()->GetCurrentIteration();
  *out_iteration = current_iter;
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1350

1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
// Writes the boosting object's NumModelPerIteration() to
// `out_tree_per_iteration`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto models_per_iter = booster->GetBoosting()->NumModelPerIteration();
  *out_tree_per_iteration = models_per_iter;
  API_END();
}

// Writes the boosting object's NumberOfTotalModel() to `out_models`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto total_models = booster->GetBoosting()->NumberOfTotalModel();
  *out_models = total_models;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1365
// Writes Booster::GetEvalCounts() to `out_len`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto eval_count = booster->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

1372
1373
1374
1375
1376
1377
// Forwards the caller's string buffers and their sizes to
// Booster::GetEvalNames and stores that call's return value in `out_len`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetEvalNames(BoosterHandle handle,
                             const int len,
                             int* out_len,
                             const size_t buffer_len,
                             size_t* out_buffer_len,
                             char** out_strs) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetEvalNames(out_strs, len, buffer_len, out_buffer_len);
  *out_len = name_count;
  API_END();
}

1384
1385
1386
1387
1388
1389
// Forwards the caller's string buffers and their sizes to
// Booster::GetFeatureNames and stores that call's return value in `out_len`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
                                const int len,
                                int* out_len,
                                const size_t buffer_len,
                                size_t* out_buffer_len,
                                char** out_strs) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetFeatureNames(out_strs, len, buffer_len, out_buffer_len);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1396
// Writes MaxFeatureIdx() + 1 of the boosting object to `out_len`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto max_feature_idx = booster->GetBoosting()->MaxFeatureIdx();
  *out_len = max_feature_idx + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1403
// Fetches GetEvalAt(data_idx) from the boosting object, writes the number of
// values to `out_len` and copies each value, converted to double, into
// `out_results` in order.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto metrics = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metrics.size());
  double* cursor = out_results;
  for (const auto& metric : metrics) {
    *cursor = static_cast<double>(metric);
    ++cursor;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1418
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1419
1420
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1421
1422
1423
1424
1425
1426
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1427
int LGBM_BoosterGetPredict(BoosterHandle handle,
1428
1429
1430
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1431
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1432
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1433
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1434
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1435
1436
}

Guolin Ke's avatar
Guolin Ke committed
1437
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1438
1439
1440
1441
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1442
                               const char* parameter,
1443
                               const char* result_filename) {
1444
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1445
1446
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1447
1448
1449
1450
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1451
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1452
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1453
                       config, result_filename);
1454
  API_END();
1455
1456
}

Guolin Ke's avatar
Guolin Ke committed
1457
// Writes `num_row` (widened to int64_t) multiplied by the boosting object's
// NumPredictOneRow(...) to `out_len`. The two flags passed to
// NumPredictOneRow say whether `predict_type` equals
// C_API_PREDICT_LEAF_INDEX or C_API_PREDICT_CONTRIB respectively.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(
    num_iteration, is_leaf_index, is_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1469
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1470
1471
1472
1473
1474
1475
1476
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1477
                              int64_t num_col,
1478
1479
                              int predict_type,
                              int num_iteration,
1480
                              const char* parameter,
1481
1482
                              int64_t* out_len,
                              double* out_result) {
1483
  API_BEGIN();
1484
1485
1486
1487
1488
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1489
1490
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1491
1492
1493
1494
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1495
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1496
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1497
  int nrow = static_cast<int>(nindptr - 1);
1498
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1499
                       config, out_result, out_len);
1500
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1501
}
1502

1503
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1504
1505
1506
1507
1508
1509
1510
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1511
                                       int64_t num_col,
1512
1513
1514
1515
1516
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1517
  API_BEGIN();
1518
1519
1520
1521
1522
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1523
1524
1525
1526
1527
1528
1529
1530
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1531
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1532
1533
1534
1535
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1536
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1547
                              const char* parameter,
1548
1549
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1550
1551
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1552
1553
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1554
1555
1556
1557
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
1558
  int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
1559
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1560
1561
1562
1563
1564
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1565
1566
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
Guolin Ke's avatar
Guolin Ke committed
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
      [&iterators, ncol](int i) {
        std::vector<std::pair<int, double>> one_row;
        one_row.reserve(ncol);
        const int tid = omp_get_thread_num();
        for (int j = 0; j < ncol; ++j) {
          auto val = iterators[tid][j].Get(i);
          if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
            one_row.emplace_back(j, val);
          }
        }
        return one_row;
      };
1579
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1580
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1581
1582
1583
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1584
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1585
1586
1587
1588
1589
1590
1591
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1592
                              const char* parameter,
1593
1594
                              int64_t* out_len,
                              double* out_result) {
1595
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1596
1597
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1598
1599
1600
1601
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1602
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1603
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1604
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1605
                       config, out_result, out_len);
1606
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1607
}
1608

1609
// Single-row variant of dense prediction: the row accessor is built by
// RowPairFunctionFromDenseMatric with a row count of 1.
// Parses `parameter` into a Config, applies a positive num_threads through
// omp_set_num_threads, then calls Booster::PredictSingleRow.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  const bool threads_requested = config.num_threads > 0;
  if (threads_requested) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_reader, config, out_result, out_len);
  API_END();
}


1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
// Predicts for data supplied as an array of row pointers.
// Parses `parameter` into a Config, applies a positive num_threads through
// omp_set_num_threads, and calls Booster::Predict with a row accessor built
// by RowPairFunctionFromDenseRows.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  const bool threads_requested = config.num_threads > 0;
  if (threads_requested) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_reader, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1656
int LGBM_BoosterSaveModel(BoosterHandle handle,
1657
                          int start_iteration,
1658
1659
                          int num_iteration,
                          const char* filename) {
1660
  API_BEGIN();
1661
  LocaleContext withLocaleContext("C");
Guolin Ke's avatar
Guolin Ke committed
1662
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1663
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1664
1665
1666
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1667
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1668
                                  int start_iteration,
1669
                                  int num_iteration,
1670
                                  int64_t buffer_len,
1671
                                  int64_t* out_len,
1672
                                  char* out_str) {
1673
  API_BEGIN();
1674
  LocaleContext withLocaleContext("C");
1675
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1676
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1677
  *out_len = static_cast<int64_t>(model.size()) + 1;
1678
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1679
    std::memcpy(out_str, model.c_str(), *out_len);
1680
1681
1682
1683
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1684
int LGBM_BoosterDumpModel(BoosterHandle handle,
1685
                          int start_iteration,
1686
                          int num_iteration,
1687
1688
                          int64_t buffer_len,
                          int64_t* out_len,
1689
                          char* out_str) {
wxchan's avatar
wxchan committed
1690
  API_BEGIN();
1691
  LocaleContext withLocaleContext("C");
wxchan's avatar
wxchan committed
1692
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1693
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1694
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1695
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1696
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1697
  }
1698
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1699
}
1700

Guolin Ke's avatar
Guolin Ke committed
1701
// Writes Booster::GetLeafValue(tree_idx, leaf_idx), converted to double, to
// `out_val`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1711
// Forwards `tree_idx`, `leaf_idx` and `val` to Booster::SetLeafValue.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
// Copies every value returned by Booster::FeatureImportance into
// `out_results`, in order. No element count is reported back; the caller's
// buffer must hold one double per returned value.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* cursor = out_results;
  for (const double importance : importances) {
    *cursor = importance;
    ++cursor;
  }
  API_END();
}

1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
// Writes Booster::UpperBoundValue() to `out_results`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const double upper_bound = booster->UpperBoundValue();
  *out_results = upper_bound;
  API_END();
}

// Writes Booster::LowerBoundValue() to `out_results`.
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const double lower_bound = booster->LowerBoundValue();
  *out_results = lower_bound;
  API_END();
}

1752
1753
1754
1755
1756
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1757
  Config config;
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Calls Network::Dispose().
// Returns 0 on success, -1 if API_END() caught an exception.
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1774
1775
1776
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1777
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1778
  if (num_machines > 1) {
1779
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1780
1781
1782
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1783

Guolin Ke's avatar
Guolin Ke committed
1784
// ---- start of some help functions
1785
1786
1787

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1788
  if (data_type == C_API_DTYPE_FLOAT32) {
1789
1790
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1791
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1792
        std::vector<double> ret(num_col);
1793
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1794
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1795
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1796
1797
1798
1799
        }
        return ret;
      };
    } else {
1800
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1801
        std::vector<double> ret(num_col);
1802
        for (int i = 0; i < num_col; ++i) {
1803
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1804
1805
1806
1807
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1808
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1809
1810
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1811
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1812
        std::vector<double> ret(num_col);
1813
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1814
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1815
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1816
1817
1818
1819
        }
        return ret;
      };
    } else {
1820
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1821
        std::vector<double> ret(num_col);
1822
        for (int i = 0; i < num_col; ++i) {
1823
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1824
1825
1826
1827
1828
        }
        return ret;
      };
    }
  }
1829
  Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
1830
  return nullptr;
1831
1832
1833
1834
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1835
1836
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1837
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1838
1839
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
Guolin Ke's avatar
Guolin Ke committed
1840
      ret.reserve(raw_values.size());
Guolin Ke's avatar
Guolin Ke committed
1841
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1842
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1843
          ret.emplace_back(i, raw_values[i]);
1844
        }
Guolin Ke's avatar
Guolin Ke committed
1845
1846
1847
      }
      return ret;
    };
1848
  }
Guolin Ke's avatar
Guolin Ke committed
1849
  return nullptr;
1850
1851
}

1852
1853
1854
1855
1856
1857
1858
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
Guolin Ke's avatar
Guolin Ke committed
1859
    ret.reserve(raw_values.size());
1860
1861
1862
1863
1864
1865
1866
1867
1868
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

1869
std::function<std::vector<std::pair<int, double>>(int idx)>
1870
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1871
  if (data_type == C_API_DTYPE_FLOAT32) {
1872
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1873
    if (indptr_type == C_API_DTYPE_INT32) {
1874
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1875
      return [=] (int idx) {
1876
1877
1878
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1879
1880
1881
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1882
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1883
          ret.emplace_back(indices[i], data_ptr[i]);
1884
1885
1886
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1887
    } else if (indptr_type == C_API_DTYPE_INT64) {
1888
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1889
      return [=] (int idx) {
1890
1891
1892
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1893
1894
1895
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1896
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1897
          ret.emplace_back(indices[i], data_ptr[i]);
1898
1899
1900
1901
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1902
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1903
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1904
    if (indptr_type == C_API_DTYPE_INT32) {
1905
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1906
      return [=] (int idx) {
1907
1908
1909
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1910
1911
1912
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1913
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1914
          ret.emplace_back(indices[i], data_ptr[i]);
1915
1916
1917
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1918
    } else if (indptr_type == C_API_DTYPE_INT64) {
1919
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1920
      return [=] (int idx) {
1921
1922
1923
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1924
1925
1926
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1927
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1928
          ret.emplace_back(indices[i], data_ptr[i]);
1929
1930
1931
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1932
1933
    }
  }
1934
  Log::Fatal("Unknown data type in RowFunctionFromCSR");
1935
  return nullptr;
1936
1937
}

Guolin Ke's avatar
Guolin Ke committed
1938
std::function<std::pair<int, double>(int idx)>
1939
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1940
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1941
  if (data_type == C_API_DTYPE_FLOAT32) {
1942
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1943
    if (col_ptr_type == C_API_DTYPE_INT32) {
1944
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1945
1946
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1947
1948
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1949
1950
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1951
        }
Guolin Ke's avatar
Guolin Ke committed
1952
1953
1954
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1955
      };
Guolin Ke's avatar
Guolin Ke committed
1956
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1957
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1958
1959
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1960
1961
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1962
1963
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1964
        }
Guolin Ke's avatar
Guolin Ke committed
1965
1966
1967
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1968
      };
Guolin Ke's avatar
Guolin Ke committed
1969
    }
Guolin Ke's avatar
Guolin Ke committed
1970
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1971
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1972
    if (col_ptr_type == C_API_DTYPE_INT32) {
1973
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1974
1975
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1976
1977
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1978
1979
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1980
        }
Guolin Ke's avatar
Guolin Ke committed
1981
1982
1983
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1984
      };
Guolin Ke's avatar
Guolin Ke committed
1985
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1986
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1987
1988
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1989
1990
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1991
1992
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1993
        }
Guolin Ke's avatar
Guolin Ke committed
1994
1995
1996
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1997
      };
Guolin Ke's avatar
Guolin Ke committed
1998
1999
    }
  }
2000
  Log::Fatal("Unknown data type in CSC matrix");
2001
  return nullptr;
2002
2003
}

Guolin Ke's avatar
Guolin Ke committed
2004
// Binds this iterator to column col_idx of the given CSC matrix by storing the
// function produced by IterateFunctionFromCSC; all arguments are forwarded unchanged.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the value of the entry whose index equals idx, or zero when the
// iterator is not positioned on such an entry.
// The cursor only moves forward: entries are pulled from iter_fun_ while the
// current index is still below idx and the column has not been exhausted.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // a negative index marks the end of the column; the loop condition stops the scan
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (index, value) entry and advances the cursor.
// Once the end has been reached, every call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // a negative index is the end marker; it is still handed back to the caller
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}