c_api.cpp 67 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Store the message of a caught standard exception via LGBM_SetLastError.
 * \param ex The caught exception; its what() text is recorded.
 * \return Always -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  LGBM_SetLastError(ex.what());
  return -1;
}
/*!
 * \brief Store an error message (thrown as std::string, or a fallback text)
 *        via LGBM_SetLastError.
 * \param ex The error text to record.
 * \return Always -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  LGBM_SetLastError(ex.c_str());
  return -1;
}

// API_BEGIN()/API_END() bracket the body of every exported C function so no
// C++ exception escapes across the C boundary: any exception is turned into a
// -1 return value (with its message stored through LGBM_SetLastError), and a
// body that falls through to API_END() returns 0.
// Exceptions are caught by const reference (no slicing, no needless mutability).
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
// Size of Booster's per-prediction-type cache of single-row predictors
// (presumably one slot per C_API_PREDICT_* code: normal, raw score, leaf index, contrib).
constexpr int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
57
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
74
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
75
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
76
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
77
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
78
    num_total_model_ = boosting->NumberOfTotalModel();
79
80
  }
  ~SingleRowPredictor() {}
Guolin Ke's avatar
Guolin Ke committed
81
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
82
83
84
85
    return early_stop_ != config.pred_early_stop ||
      early_stop_freq_ != config.pred_early_stop_freq ||
      early_stop_margin_ != config.pred_early_stop_margin ||
      iter_ != iter ||
Guolin Ke's avatar
Guolin Ke committed
86
      num_total_model_ != boosting->NumberOfTotalModel();
87
  }
Guolin Ke's avatar
Guolin Ke committed
88

89
90
91
92
93
94
95
96
97
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
98
class Booster {
Nikita Titov's avatar
Nikita Titov committed
99
 public:
Guolin Ke's avatar
Guolin Ke committed
100
  explicit Booster(const char* filename) {
101
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
102
103
  }

Guolin Ke's avatar
Guolin Ke committed
104
  Booster(const Dataset* train_data,
105
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
106
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
107
    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
112
    if (config_.input_model.size() > 0) {
113
114
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
115
    }
Guolin Ke's avatar
Guolin Ke committed
116

Guolin Ke's avatar
Guolin Ke committed
117
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
118

119
120
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
121
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
122
    if (config_.tree_learner == std::string("feature")) {
123
      Log::Fatal("Do not support feature parallel in c api");
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
126
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
127
      config_.tree_learner = "serial";
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
130
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
131
132
133
134
135
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
  }

  ~Booster() {
  }
140

141
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
142
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
152
153
154
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
155
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
156
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
157
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
162
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
163
164
165
166
167
168
169
170
171
172
173
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
174
175
176
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
177
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
178
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
179
    if (param.count("num_class")) {
180
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
181
    }
Guolin Ke's avatar
Guolin Ke committed
182
183
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
184
    }
Guolin Ke's avatar
Guolin Ke committed
185
    if (param.count("metric")) {
186
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
187
    }
Guolin Ke's avatar
Guolin Ke committed
188
189

    config_.Set(param);
190
191
192
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
193
194
195

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
196
197
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
198
199
200
201
202
203
204
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
205
206
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
207
    }
Guolin Ke's avatar
Guolin Ke committed
208

Guolin Ke's avatar
Guolin Ke committed
209
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
210
211
212
213
214
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
215
216
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
217
218
219
220
221
222
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
223
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
224
  }
Guolin Ke's avatar
Guolin Ke committed
225

226
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
227
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
228
    return boosting_->TrainOneIter(nullptr, nullptr);
229
230
  }

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
235
236
237
238
239
240
241
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

242
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
243
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
244
    return boosting_->TrainOneIter(gradients, hessians);
245
246
  }

wxchan's avatar
wxchan committed
247
248
249
250
251
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

252
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
253
254
255
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
256
257
258
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
259
    }
260
    std::lock_guard<std::mutex> lock(mutex_);
261
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
262
263
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
264
                                                                       config, num_iteration));
265
266
267
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
268
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
269

270
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
271
272
273
  }


274
  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
Guolin Ke's avatar
Guolin Ke committed
275
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
276
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
277
               double* out_result, int64_t* out_len) {
278
279
280
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
281
    }
wxchan's avatar
wxchan committed
282
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
283
284
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
285
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
286
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
287
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
288
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
289
      is_raw_score = true;
290
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
291
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
292
293
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
294
    }
Guolin Ke's avatar
Guolin Ke committed
295

Guolin Ke's avatar
Guolin Ke committed
296
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
297
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
298
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
299
    auto pred_fun = predictor.GetPredictFunction();
300
301
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
302
    for (int i = 0; i < nrow; ++i) {
303
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
304
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
305
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
306
      pred_fun(one_row, pred_wrt_ptr);
307
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
308
    }
309
    OMP_THROW_EX();
310
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
311
312
313
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
314
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
315
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
316
317
318
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
319
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
324
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
325
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
326
327
328
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
329
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
330
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
331
    bool bool_data_has_header = data_has_header > 0 ? true : false;
332
    predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
Guolin Ke's avatar
Guolin Ke committed
333
334
  }

Guolin Ke's avatar
Guolin Ke committed
335
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
336
337
338
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

339
340
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
341
  }
342

343
  void LoadModelFromString(const char* model_str) {
344
345
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
346
347
  }

348
349
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
350
351
  }

352
  std::string DumpModel(int start_iteration, int num_iteration) {
353
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
354
  }
355

356
357
358
359
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
360
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
361
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
362
363
364
365
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
366
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
367
368
  }

369
  void ShuffleModels(int start_iter, int end_iter) {
370
    std::lock_guard<std::mutex> lock(mutex_);
371
    boosting_->ShuffleModels(start_iter, end_iter);
372
373
  }

wxchan's avatar
wxchan committed
374
375
376
377
378
379
380
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
381

wxchan's avatar
wxchan committed
382
383
384
385
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
386
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
387
388
389
390
391
392
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
393
394
395
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
396
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
397
398
399
400
401
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
402
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
403

Nikita Titov's avatar
Nikita Titov committed
404
 private:
wxchan's avatar
wxchan committed
405
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
406
  std::unique_ptr<Boosting> boosting_;
407
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
408

Guolin Ke's avatar
Guolin Ke committed
409
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
410
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
411
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
412
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
413
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
414
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
415
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
416
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
417
418
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
419
420
};

421
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
422
423
424

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
425
426
427
428
429
430
431
432
// Forward declarations of helpers (defined elsewhere) that adapt the raw
// C API input layouts into per-row accessor functions.

// Dense matrix -> accessor returning row `row_idx` as a vector of doubles.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Dense matrix -> accessor returning row `row_idx` as (index, value) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// Array of dense row pointers -> accessor returning (index, value) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data,
                             int num_col,
                             int data_type);

// CSR matrix (indptr / indices / data) -> accessor returning (index, value) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
439
440
441

// Walks the rows of a single column of a CSC matrix.
class CSC_RowIterator {
 public:
  /*!
   * \brief Iterate over column col_idx of the CSC matrix given by
   *        (col_ptr, indices, data); the *_type arguments give the element
   *        types of col_ptr and data.
   */
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value at row idx; rows must be requested in ascending order. */
  double Get(int idx);

  /*! \brief Next stored (row, value) pair; a negative row means the column is exhausted. */
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
461
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
462
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
463
464
}

Guolin Ke's avatar
Guolin Ke committed
465
int LGBM_DatasetCreateFromFile(const char* filename,
466
467
468
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
469
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
470
471
  auto param = Config::Str2Map(parameters);
  Config config;
472
473
474
475
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
476
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
477
  if (reference == nullptr) {
478
479
480
481
482
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
483
  } else {
484
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
485
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
486
  }
487
  API_END();
Guolin Ke's avatar
Guolin Ke committed
488
489
}

490

Guolin Ke's avatar
Guolin Ke committed
491
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
492
493
494
495
496
497
498
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
499
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
500
501
  auto param = Config::Str2Map(parameters);
  Config config;
502
503
504
505
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
506
  DatasetLoader loader(config, nullptr, 1, nullptr);
507
508
509
510
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
511
512
}

513

Guolin Ke's avatar
Guolin Ke committed
514
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
515
516
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
517
518
519
520
521
522
523
524
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
525
int LGBM_DatasetPushRows(DatasetHandle dataset,
526
527
528
529
530
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
531
532
533
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
534
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
535
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
536
  for (int i = 0; i < nrow; ++i) {
537
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
538
539
540
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
541
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
542
  }
543
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
544
545
546
547
548
549
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
550
/*!
 * \brief Push the rows of a CSR matrix into dataset, starting at start_row.
 * The row count is nindptr - 1. Rows are pushed in parallel with OpenMP, and
 * FinishLoad() is called once the last row of the dataset has been filled.
 * \return 0 on success, -1 on failure (see API_END).
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t /* presumably ncol; unused here */,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t row_count = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < row_count; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto values = row_reader(row);
    const auto dest_row = static_cast<data_size_t>(start_row + row);
    target->PushOneRow(thread_id, dest_row, values);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool filled_to_end = (start_row + row_count == static_cast<int64_t>(target->num_data()));
  if (filled_to_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
581
int LGBM_DatasetCreateFromMat(const void* data,
582
583
584
585
586
587
588
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
610
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
611
612
  auto param = Config::Str2Map(parameters);
  Config config;
613
614
615
616
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
617
  std::unique_ptr<Dataset> ret;
618
619
620
621
622
623
624
625
626
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
627

Guolin Ke's avatar
Guolin Ke committed
628
629
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
630
    Random rand(config.data_random_seed);
631
632
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
633
    sample_cnt = static_cast<int>(sample_indices.size());
634
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
635
    std::vector<std::vector<int>> sample_idx(ncol);
636
637
638

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
639
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
640
      auto idx = sample_indices[i];
641
642
643
644
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
645

646
647
648
649
650
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
651
        }
Guolin Ke's avatar
Guolin Ke committed
652
653
      }
    }
Guolin Ke's avatar
Guolin Ke committed
654
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
655
656
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
657
                                            ncol,
658
                                            Common::VectorSize<double>(sample_values).data(),
659
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
660
  } else {
661
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
662
    ret->CreateValid(
663
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
664
  }
665
666
667
668
669
670
671
672
673
674
675
676
677
678
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
679
680
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
681
  *out = ret.release();
682
  API_END();
683
684
}

Guolin Ke's avatar
Guolin Ke committed
685
int LGBM_DatasetCreateFromCSR(const void* indptr,
686
687
688
689
690
691
692
693
694
695
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
696
  API_BEGIN();
697
698
699
700
701
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
702
703
  auto param = Config::Str2Map(parameters);
  Config config;
704
705
706
707
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
708
  std::unique_ptr<Dataset> ret;
709
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
710
711
712
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
713
714
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
715
    auto sample_indices = rand.Sample(nrow, sample_cnt);
716
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
717
718
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
719
720
721
722
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
723
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
724
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
725
726
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
727
728
729
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
730
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
731
732
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
733
                                            static_cast<int>(num_col),
734
735
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
736
  } else {
737
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
738
    ret->CreateValid(
739
      reinterpret_cast<const Dataset*>(reference));
740
  }
741
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
742
  #pragma omp parallel for schedule(static)
743
  for (int i = 0; i < nindptr - 1; ++i) {
744
    OMP_LOOP_EX_BEGIN();
745
746
747
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
748
    OMP_LOOP_EX_END();
749
  }
750
  OMP_THROW_EX();
751
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
752
  *out = ret.release();
753
  API_END();
754
755
}

756
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
757
758
759
760
761
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
762
  API_BEGIN();
763
764
765
766
767
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
799
800
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
801
                                            static_cast<int>(num_col),
802
803
804
805
806
807
808
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
809

810
811
812
813
814
815
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
816
817
818
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
819
820
821
822
823
824
825
826
827
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
828
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
829
830
831
832
833
834
835
836
837
838
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
839
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
840
841
  auto param = Config::Str2Map(parameters);
  Config config;
842
843
844
845
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
846
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
847
848
849
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
850
851
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
852
    auto sample_indices = rand.Sample(nrow, sample_cnt);
853
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
854
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
855
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
856
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
857
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
858
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
859
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
860
861
862
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
863
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
864
865
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
866
867
        }
      }
868
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
869
    }
870
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
871
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
872
873
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
874
875
876
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
877
  } else {
878
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
879
    ret->CreateValid(
880
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
881
  }
882
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
883
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
884
  for (int i = 0; i < ncol_ptr - 1; ++i) {
885
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
886
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
887
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
888
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
889
890
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
891
892
893
894
895
896
897
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
898
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
899
    }
900
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
901
  }
902
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
903
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
904
  *out = ret.release();
905
  API_END();
Guolin Ke's avatar
Guolin Ke committed
906
907
}

Guolin Ke's avatar
Guolin Ke committed
908
int LGBM_DatasetGetSubset(
909
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
910
911
912
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
913
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
914
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
915
916
  auto param = Config::Str2Map(parameters);
  Config config;
917
918
919
920
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
921
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
922
  CHECK(num_used_row_indices > 0);
923
924
925
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
926
927
928
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
929
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
930
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
931
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
932
933
934
935
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
936
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
937
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
938
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
939
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
940
941
942
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
943
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
944
945
946
947
948
949
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
950
int LGBM_DatasetGetFeatureNames(
951
952
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
953
  int* num_feature_names) {
954
955
956
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
957
958
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
959
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
960
961
962
963
  }
  API_END();
}

964
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
965
int LGBM_DatasetFree(DatasetHandle handle) {
966
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
967
  delete reinterpret_cast<Dataset*>(handle);
968
  API_END();
969
970
}

Guolin Ke's avatar
Guolin Ke committed
971
int LGBM_DatasetSaveBinary(DatasetHandle handle,
972
                           const char* filename) {
973
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
974
975
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
976
  API_END();
977
978
}

979
980
981
982
983
984
985
986
// C API: forwards to Dataset::DumpTextFile(filename).
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
987
// C API: sets a named metadata field from a typed array.
// Supported element types are C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 and
// C_API_DTYPE_FLOAT64. An unsupported type, or a setter reporting failure,
// raises std::runtime_error, which API_END converts into a -1 return.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      stored = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
      break;
    case C_API_DTYPE_INT32:
      stored = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
      break;
    case C_API_DTYPE_FLOAT64:
      stored = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
      break;
    default:
      stored = false;
      break;
  }
  if (!stored) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1006
// C API: looks up a named metadata field.
// The typed getters are tried in a fixed order (float, int, double, int8);
// the first that succeeds determines *out_type. If none succeeds,
// std::runtime_error is raised. A null *out_ptr is reported with
// *out_len == 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool found = true;
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (dataset->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    found = false;
  }
  if (!found) { throw std::runtime_error("Field not found"); }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

1032
1033
1034
1035
1036
1037
1038
// C API: forwards `parameters` to Dataset::ResetConfig.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1039
int LGBM_DatasetGetNumData(DatasetHandle handle,
1040
                           int* out) {
1041
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1042
1043
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1044
  API_END();
1045
1046
}

Guolin Ke's avatar
Guolin Ke committed
1047
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1048
                              int* out) {
1049
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1050
1051
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1052
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1053
}
1054

1055
1056
1057
1058
1059
1060
1061
1062
1063
// C API: calls target->addFeaturesFrom(source) on the two datasets.
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* source_dataset = reinterpret_cast<Dataset*>(source);
  reinterpret_cast<Dataset*>(target)->addFeaturesFrom(source_dataset);
  API_END();
}

1064
1065
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1066
int LGBM_BoosterCreate(const DatasetHandle train_data,
1067
1068
                       const char* parameters,
                       BoosterHandle* out) {
1069
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1070
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1071
1072
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1073
  API_END();
1074
1075
}

Guolin Ke's avatar
Guolin Ke committed
1076
int LGBM_BoosterCreateFromModelfile(
1077
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1078
  int* out_num_iterations,
1079
  BoosterHandle* out) {
1080
  API_BEGIN();
wxchan's avatar
wxchan committed
1081
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1082
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1083
  *out = ret.release();
1084
  API_END();
1085
1086
}

Guolin Ke's avatar
Guolin Ke committed
1087
int LGBM_BoosterLoadModelFromString(
1088
1089
1090
1091
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1092
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1093
1094
1095
1096
1097
1098
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1099
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1100
int LGBM_BoosterFree(BoosterHandle handle) {
1101
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1102
  delete reinterpret_cast<Booster*>(handle);
1103
  API_END();
1104
1105
}

1106
// C API: forwards to Booster::ShuffleModels(start_iter, end_iter).
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1113
// C API: merges the Booster behind `other_handle` into the one behind
// `handle` via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* donor = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(donor);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1122
int LGBM_BoosterAddValidData(BoosterHandle handle,
1123
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1124
1125
1126
1127
1128
1129
1130
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1131
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1132
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1133
1134
1135
1136
1137
1138
1139
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1140
// C API: forwards `parameters` to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1147
// C API: stores the boosting model's NumberOfClasses() in *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1154
1155
1156
1157
1158
1159
1160
// C API: forwards the nrow x ncol `leaf_preds` array to Booster::Refit.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1161
// C API: runs one training iteration.
// *is_finished is 1 when Booster::TrainOneIter() returns true, else 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1172
// C API: runs one training iteration with caller-supplied gradients and
// hessians. *is_finished is 1 when TrainOneIter(grad, hess) returns true,
// else 0. Builds with SCORE_T_USE_DOUBLE reject this call with a fatal error.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1190
// C API: forwards to Booster::RollbackOneIter().
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1197
// C API: stores the boosting model's GetCurrentIteration() in *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1203

1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
// C API: stores the boosting model's NumModelPerIteration() in
// *out_tree_per_iteration.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  *out_tree_per_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumModelPerIteration();
  API_END();
}

// C API: stores the boosting model's NumberOfTotalModel() in *out_models.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  *out_models = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1218
// C API: stores Booster::GetEvalCounts() in *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1225
// C API: Booster::GetEvalNames fills `out_strs` and returns the count,
// which is stored in *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1232
// C API: Booster::GetFeatureNames fills `out_strs` and returns the count,
// which is stored in *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1239
// C API: the feature count reported is MaxFeatureIdx() + 1.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = 1 + boosting->MaxFeatureIdx();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1246
// Copies the evaluation results for dataset `data_idx` (from Boosting::GetEvalAt)
// into out_results and stores how many values were written in *out_len.
// out_results must have room for every returned value; no bound is checked here.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_values.size());
  double* dst = out_results;
  for (const auto& value : eval_values) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1261
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1262
1263
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1264
1265
1266
1267
1268
1269
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1270
int LGBM_BoosterGetPredict(BoosterHandle handle,
1271
1272
1273
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1274
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1275
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1276
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1277
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1278
1279
}

Guolin Ke's avatar
Guolin Ke committed
1280
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1281
1282
1283
1284
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1285
                               const char* parameter,
1286
                               const char* result_filename) {
1287
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1288
1289
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1290
1291
1292
1293
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1294
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1295
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1296
                       config, result_filename);
1297
  API_END();
1298
1299
}

Guolin Ke's avatar
Guolin Ke committed
1300
// Computes the output length for predicting num_row rows with predict_type.
// The result is num_row times the per-row count from Boosting::NumPredictOneRow,
// multiplied in 64-bit arithmetic.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(num_iteration, is_leaf_index, is_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1312
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1313
1314
1315
1316
1317
1318
1319
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1320
                              int64_t num_col,
1321
1322
                              int predict_type,
                              int num_iteration,
1323
                              const char* parameter,
1324
1325
                              int64_t* out_len,
                              double* out_result) {
1326
  API_BEGIN();
1327
1328
1329
1330
1331
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1332
1333
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1334
1335
1336
1337
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1338
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1339
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1340
  int nrow = static_cast<int>(nindptr - 1);
1341
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1342
                       config, out_result, out_len);
1343
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1344
}
1345

1346
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1347
1348
1349
1350
1351
1352
1353
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1354
                                       int64_t num_col,
1355
1356
1357
1358
1359
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1360
  API_BEGIN();
1361
1362
1363
1364
1365
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1366
1367
1368
1369
1370
1371
1372
1373
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1374
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1375
1376
1377
1378
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1379
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1390
                              const char* parameter,
1391
1392
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1393
1394
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1395
1396
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1407
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1408
1409
1410
1411
1412
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1413
1414
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1415
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1416
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1417
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1418
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1419
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1420
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1421
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1422
1423
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1424
1425
    return one_row;
  };
1426
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1427
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1428
1429
1430
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1431
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1432
1433
1434
1435
1436
1437
1438
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1439
                              const char* parameter,
1440
1441
                              int64_t* out_len,
                              double* out_result) {
1442
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1443
1444
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1445
1446
1447
1448
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1449
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1450
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1451
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1452
                       config, out_result, out_len);
1453
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1454
}
1455

1456
// Predicts for a single dense row of ncol values through Booster::PredictSingleRow.
// The row is treated as a 1 x ncol matrix.
// `parameter` is parsed with Config::Str2Map. A positive num_threads value in it is
// applied via omp_set_num_threads.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  auto params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(params);
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  auto* booster = reinterpret_cast<Booster*>(handle);
  const int single_row = 1;
  auto row_accessor = RowPairFunctionFromDenseMatric(data, single_row, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_accessor, predict_config, out_result, out_len);
  API_END();
}


1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
// Predicts for nrow dense rows, where data[r] points at the ncol values of row r.
// `parameter` is parsed with Config::Str2Map. A positive num_threads value in it is
// applied via omp_set_num_threads.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  auto params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(params);
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto row_accessor = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_accessor,
                   predict_config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1503
int LGBM_BoosterSaveModel(BoosterHandle handle,
1504
                          int start_iteration,
1505
1506
                          int num_iteration,
                          const char* filename) {
1507
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1508
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1509
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1510
1511
1512
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1513
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1514
                                  int start_iteration,
1515
                                  int num_iteration,
1516
                                  int64_t buffer_len,
1517
                                  int64_t* out_len,
1518
                                  char* out_str) {
1519
1520
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1521
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1522
  *out_len = static_cast<int64_t>(model.size()) + 1;
1523
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1524
    std::memcpy(out_str, model.c_str(), *out_len);
1525
1526
1527
1528
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1529
int LGBM_BoosterDumpModel(BoosterHandle handle,
1530
                          int start_iteration,
1531
                          int num_iteration,
1532
1533
                          int64_t buffer_len,
                          int64_t* out_len,
1534
                          char* out_str) {
wxchan's avatar
wxchan committed
1535
1536
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1537
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1538
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1539
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1540
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1541
  }
1542
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1543
}
1544

Guolin Ke's avatar
Guolin Ke committed
1545
// Reads the output value of leaf `leaf_idx` in tree `tree_idx` into *out_val.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1555
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx` with `val`.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
// Copies Booster::FeatureImportance(num_iteration, importance_type) into out_results.
// out_results must have room for every returned entry; no bound is checked here.
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst++ = importance;
  }
  API_END();
}

1578
1579
1580
1581
1582
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1583
  Config config;
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Releases network state by calling Network::Dispose().
// Returns 0 on success; API_END converts any thrown exception into -1.
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1600
1601
1602
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1603
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1604
  if (num_machines > 1) {
1605
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1606
1607
1608
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1609

Guolin Ke's avatar
Guolin Ke committed
1610
// ---- start of some help functions
1611
1612
1613

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1614
  if (data_type == C_API_DTYPE_FLOAT32) {
1615
1616
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1617
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1618
        std::vector<double> ret(num_col);
1619
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1620
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1621
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1622
1623
1624
1625
        }
        return ret;
      };
    } else {
1626
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1627
        std::vector<double> ret(num_col);
1628
        for (int i = 0; i < num_col; ++i) {
1629
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1630
1631
1632
1633
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1634
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1635
1636
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1637
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1638
        std::vector<double> ret(num_col);
1639
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1640
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1641
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1642
1643
1644
1645
        }
        return ret;
      };
    } else {
1646
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1647
        std::vector<double> ret(num_col);
1648
        for (int i = 0; i < num_col; ++i) {
1649
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1650
1651
1652
1653
1654
        }
        return ret;
      };
    }
  }
1655
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1656
1657
1658
1659
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1660
1661
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1662
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1663
1664
1665
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1666
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1667
          ret.emplace_back(i, raw_values[i]);
1668
        }
Guolin Ke's avatar
Guolin Ke committed
1669
1670
1671
      }
      return ret;
    };
1672
  }
Guolin Ke's avatar
Guolin Ke committed
1673
  return nullptr;
1674
1675
}

1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
// Row accessor over an array of row pointers: data[r] points at the num_col values of
// row r, whose element type is given by data_type.
// Each call returns the row as sparse (column index, value) pairs. A value is kept when
// it is NaN or its magnitude exceeds kZeroThreshold.
// The data_type is only validated when a row is first requested.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    // Treat the single row as a 1 x num_col row-major matrix and read its only row.
    auto single_row = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    const std::vector<double> values = single_row(0);
    std::vector<std::pair<int, double>> sparse_row;
    const int count = static_cast<int>(values.size());
    for (int col = 0; col < count; ++col) {
      const double value = values[col];
      if (std::isnan(value) || std::fabs(value) > kZeroThreshold) {
        sparse_row.emplace_back(col, value);
      }
    }
    return sparse_row;
  };
}

1692
std::function<std::vector<std::pair<int, double>>(int idx)>
1693
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1694
  if (data_type == C_API_DTYPE_FLOAT32) {
1695
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1696
    if (indptr_type == C_API_DTYPE_INT32) {
1697
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1698
      return [=] (int idx) {
1699
1700
1701
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1702
1703
1704
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1705
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1706
          ret.emplace_back(indices[i], data_ptr[i]);
1707
1708
1709
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1710
    } else if (indptr_type == C_API_DTYPE_INT64) {
1711
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1712
      return [=] (int idx) {
1713
1714
1715
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1716
1717
1718
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1719
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1720
          ret.emplace_back(indices[i], data_ptr[i]);
1721
1722
1723
1724
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1725
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1726
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1727
    if (indptr_type == C_API_DTYPE_INT32) {
1728
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1729
      return [=] (int idx) {
1730
1731
1732
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1733
1734
1735
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1736
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1737
          ret.emplace_back(indices[i], data_ptr[i]);
1738
1739
1740
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1741
    } else if (indptr_type == C_API_DTYPE_INT64) {
1742
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1743
      return [=] (int idx) {
1744
1745
1746
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1747
1748
1749
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1750
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1751
          ret.emplace_back(indices[i], data_ptr[i]);
1752
1753
1754
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1755
1756
    }
  }
1757
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1758
1759
}

Guolin Ke's avatar
Guolin Ke committed
1760
std::function<std::pair<int, double>(int idx)>
1761
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1762
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1763
  if (data_type == C_API_DTYPE_FLOAT32) {
1764
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1765
    if (col_ptr_type == C_API_DTYPE_INT32) {
1766
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1767
1768
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1769
1770
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1771
1772
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1773
        }
Guolin Ke's avatar
Guolin Ke committed
1774
1775
1776
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1777
      };
Guolin Ke's avatar
Guolin Ke committed
1778
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1779
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1780
1781
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1782
1783
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1784
1785
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1786
        }
Guolin Ke's avatar
Guolin Ke committed
1787
1788
1789
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1790
      };
Guolin Ke's avatar
Guolin Ke committed
1791
    }
Guolin Ke's avatar
Guolin Ke committed
1792
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1793
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1794
    if (col_ptr_type == C_API_DTYPE_INT32) {
1795
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1796
1797
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1798
1799
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1800
1801
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1802
        }
Guolin Ke's avatar
Guolin Ke committed
1803
1804
1805
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1806
      };
Guolin Ke's avatar
Guolin Ke committed
1807
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1808
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1809
1810
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1811
1812
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1813
1814
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1815
        }
Guolin Ke's avatar
Guolin Ke committed
1816
1817
1818
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1819
      };
Guolin Ke's avatar
Guolin Ke committed
1820
1821
    }
  }
1822
  throw std::runtime_error("Unknown data type in CSC matrix");
1823
1824
}

Guolin Ke's avatar
Guolin Ke committed
1825
// Binds this iterator to column `col_idx` of a CSC matrix.
// The per-column cursor comes from IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the value at row `idx` of this column, or 0.0 if no entry is stored there.
// The cursor only moves forward, so callers are expected to query rows in
// non-decreasing order. A smaller idx after the cursor has moved past it yields 0
// unless it equals the current row.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // The column has no more stored entries.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (idx == cur_idx_) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) entry of this column and advances past it.
// Once exhausted, it returns (-1, 0.0); the first such result also sets is_end_.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  const auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}