"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b6c71e5e941b885b1a5169d460c430df9d392fa9"
c_api.cpp 67.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Record the message of a caught std::exception via LGBM_SetLastError.
 * \param ex The caught exception; its what() text becomes the error message.
 * \return -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  LGBM_SetLastError(ex.what());
  return -1;
}

/*!
 * \brief Record the text of an exception thrown as std::string via LGBM_SetLastError.
 * \param ex The caught string.
 * \return -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  LGBM_SetLastError(ex.c_str());
  return -1;
}

// API_BEGIN() / API_END() bracket the body of each C API function so that no
// C++ exception escapes across the C boundary: any exception thrown by the body
// is turned into an error message (via LGBM_APIHandleException) and a -1 return
// value, while normal completion returns 0. Exceptions are caught by const
// reference, which avoids slicing/copies and matches any cv-qualification.
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
// Number of slots in the per-predict_type single-row predictor cache kept by
// Booster (the cache array is indexed directly by predict_type).
constexpr int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
57
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
74
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
75
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
76
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
77
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
78
    num_total_model_ = boosting->NumberOfTotalModel();
79
80
  }
  ~SingleRowPredictor() {}
Guolin Ke's avatar
Guolin Ke committed
81
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
82
83
84
85
    return early_stop_ != config.pred_early_stop ||
      early_stop_freq_ != config.pred_early_stop_freq ||
      early_stop_margin_ != config.pred_early_stop_margin ||
      iter_ != iter ||
Guolin Ke's avatar
Guolin Ke committed
86
      num_total_model_ != boosting->NumberOfTotalModel();
87
  }
Guolin Ke's avatar
Guolin Ke committed
88

89
90
91
92
93
94
95
96
97
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
98
class Booster {
Nikita Titov's avatar
Nikita Titov committed
99
 public:
Guolin Ke's avatar
Guolin Ke committed
100
  explicit Booster(const char* filename) {
101
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
102
103
  }

Guolin Ke's avatar
Guolin Ke committed
104
  Booster(const Dataset* train_data,
105
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
106
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
107
    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
112
    if (config_.input_model.size() > 0) {
113
114
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
115
    }
Guolin Ke's avatar
Guolin Ke committed
116

Guolin Ke's avatar
Guolin Ke committed
117
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
118

119
120
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
121
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
122
    if (config_.tree_learner == std::string("feature")) {
123
      Log::Fatal("Do not support feature parallel in c api");
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
126
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
127
      config_.tree_learner = "serial";
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
130
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
131
132
133
134
135
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
  }

  ~Booster() {
  }
140

141
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
142
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
152
153
154
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
155
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
156
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
157
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
162
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
163
164
165
166
167
168
169
170
171
172
173
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
174
175
176
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
177
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
178
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
179
    if (param.count("num_class")) {
180
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
181
    }
Guolin Ke's avatar
Guolin Ke committed
182
183
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
184
    }
Guolin Ke's avatar
Guolin Ke committed
185
    if (param.count("metric")) {
186
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
187
    }
Guolin Ke's avatar
Guolin Ke committed
188
189

    config_.Set(param);
190
191
192
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
193
194
195

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
196
197
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
198
199
200
201
202
203
204
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
205
206
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
207
    }
Guolin Ke's avatar
Guolin Ke committed
208

Guolin Ke's avatar
Guolin Ke committed
209
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
210
211
212
213
214
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
215
216
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
217
218
219
220
221
222
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
223
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
224
  }
Guolin Ke's avatar
Guolin Ke committed
225

226
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
227
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
228
    return boosting_->TrainOneIter(nullptr, nullptr);
229
230
  }

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
235
236
237
238
239
240
241
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

242
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
243
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
244
    return boosting_->TrainOneIter(gradients, hessians);
245
246
  }

wxchan's avatar
wxchan committed
247
248
249
250
251
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

252
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
253
254
255
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
256
257
258
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
259
    }
260
    std::lock_guard<std::mutex> lock(mutex_);
261
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
262
263
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
264
                                                                       config, num_iteration));
265
266
267
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
268
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
269

270
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
271
272
273
  }


274
  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
Guolin Ke's avatar
Guolin Ke committed
275
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
276
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
277
               double* out_result, int64_t* out_len) {
278
279
280
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
281
    }
wxchan's avatar
wxchan committed
282
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
283
284
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
285
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
286
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
287
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
288
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
289
      is_raw_score = true;
290
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
291
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
292
293
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
294
    }
Guolin Ke's avatar
Guolin Ke committed
295

Guolin Ke's avatar
Guolin Ke committed
296
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
297
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
298
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
299
    auto pred_fun = predictor.GetPredictFunction();
300
301
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
302
    for (int i = 0; i < nrow; ++i) {
303
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
304
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
305
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
306
      pred_fun(one_row, pred_wrt_ptr);
307
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
308
    }
309
    OMP_THROW_EX();
310
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
311
312
313
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
314
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
315
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
316
317
318
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
319
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
324
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
325
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
326
327
328
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
329
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
330
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
331
    bool bool_data_has_header = data_has_header > 0 ? true : false;
332
    predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
Guolin Ke's avatar
Guolin Ke committed
333
334
  }

Guolin Ke's avatar
Guolin Ke committed
335
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
336
337
338
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

339
340
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
341
  }
342

343
  void LoadModelFromString(const char* model_str) {
344
345
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
346
347
  }

348
349
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
350
351
  }

352
  std::string DumpModel(int start_iteration, int num_iteration) {
353
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
354
  }
355

356
357
358
359
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
360
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
361
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
362
363
364
365
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
366
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
367
368
  }

369
  void ShuffleModels(int start_iter, int end_iter) {
370
    std::lock_guard<std::mutex> lock(mutex_);
371
    boosting_->ShuffleModels(start_iter, end_iter);
372
373
  }

wxchan's avatar
wxchan committed
374
375
376
377
378
379
380
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
381

wxchan's avatar
wxchan committed
382
383
384
385
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
386
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
387
388
389
390
391
392
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
393
394
395
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
396
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
397
398
399
400
401
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
402
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
403

Nikita Titov's avatar
Nikita Titov committed
404
 private:
wxchan's avatar
wxchan committed
405
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
406
  std::unique_ptr<Boosting> boosting_;
407
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
408

Guolin Ke's avatar
Guolin Ke committed
409
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
410
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
411
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
412
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
413
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
414
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
415
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
416
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
417
418
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
419
420
};

421
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
422
423
424

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
425
426
427
428
429
430
431
432
// Helper factories that wrap raw C API input buffers into per-row accessors.
// They are only declared here; the C API functions below call them.

// Dense matrix -> accessor returning one row as a vector of doubles.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Dense matrix -> accessor returning one row as (int, double) pairs
// (presumably (column index, value) — confirm against the definition).
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// Array of dense row buffers -> accessor returning one row as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data,
                             int num_col,
                             int data_type);

// CSR matrix (indptr / indices / data) -> accessor returning one row as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
439
440
441

// Row iterator over one column of a CSC matrix.
// (Constructor, Get and NextNonZero are defined elsewhere in this file.)
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Return the value at row idx; successive calls must use ascending idx.
  double Get(int idx);
  // Return the next non-zero (index, value) pair; an index < 0 means no more data.
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;  // double literal for a double member (was 0.0f)
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
461
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
462
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
463
464
}

Guolin Ke's avatar
Guolin Ke committed
465
int LGBM_DatasetCreateFromFile(const char* filename,
466
467
468
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
469
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
470
471
  auto param = Config::Str2Map(parameters);
  Config config;
472
473
474
475
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
476
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
477
  if (reference == nullptr) {
478
479
480
481
482
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
483
  } else {
484
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
485
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
486
  }
487
  API_END();
Guolin Ke's avatar
Guolin Ke committed
488
489
}

490

Guolin Ke's avatar
Guolin Ke committed
491
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
492
493
494
495
496
497
498
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
499
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
500
501
  auto param = Config::Str2Map(parameters);
  Config config;
502
503
504
505
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
506
  DatasetLoader loader(config, nullptr, 1, nullptr);
507
508
509
510
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
511
512
}

513

Guolin Ke's avatar
Guolin Ke committed
514
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
515
516
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
517
518
519
520
521
522
523
524
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
525
int LGBM_DatasetPushRows(DatasetHandle dataset,
526
527
528
529
530
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
531
532
533
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
534
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
535
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
536
  for (int i = 0; i < nrow; ++i) {
537
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
538
539
540
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
541
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
542
  }
543
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
544
545
546
547
548
549
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
550
/*!
 * \brief C API: push the rows of a CSR matrix into a dataset, starting at start_row.
 * The unnamed int64_t parameter is not used by this function.
 * \return 0 on success, -1 on error (see API_END).
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t /* unused */,
                              int64_t start_row) {
  API_BEGIN();
  auto* target = reinterpret_cast<Dataset*>(dataset);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // nindptr row offsets delimit nindptr - 1 rows
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto features = row_at(row);
    target->PushOneRow(thread_id, static_cast<data_size_t>(start_row + row), features);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // when this chunk ends exactly at the dataset's last row, finish loading
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
581
int LGBM_DatasetCreateFromMat(const void* data,
582
583
584
585
586
587
588
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
610
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
611
612
  auto param = Config::Str2Map(parameters);
  Config config;
613
614
615
616
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
617
  std::unique_ptr<Dataset> ret;
618
619
620
621
622
623
624
625
626
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
627

Guolin Ke's avatar
Guolin Ke committed
628
629
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
630
    Random rand(config.data_random_seed);
631
632
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
633
    sample_cnt = static_cast<int>(sample_indices.size());
634
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
635
    std::vector<std::vector<int>> sample_idx(ncol);
636
637
638

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
639
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
640
      auto idx = sample_indices[i];
641
642
643
644
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
645

646
647
648
649
650
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
651
        }
Guolin Ke's avatar
Guolin Ke committed
652
653
      }
    }
Guolin Ke's avatar
Guolin Ke committed
654
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
655
656
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
657
                                            ncol,
658
                                            Common::VectorSize<double>(sample_values).data(),
659
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
660
  } else {
661
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
662
    ret->CreateValid(
663
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
664
  }
665
666
667
668
669
670
671
672
673
674
675
676
677
678
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
679
680
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
681
  *out = ret.release();
682
  API_END();
683
684
}

Guolin Ke's avatar
Guolin Ke committed
685
int LGBM_DatasetCreateFromCSR(const void* indptr,
686
687
688
689
690
691
692
693
694
695
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
696
  API_BEGIN();
697
698
699
700
701
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
702
703
  auto param = Config::Str2Map(parameters);
  Config config;
704
705
706
707
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
708
  std::unique_ptr<Dataset> ret;
709
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
710
711
712
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
713
714
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
715
    auto sample_indices = rand.Sample(nrow, sample_cnt);
716
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
717
718
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
719
720
721
722
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
723
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
724
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
725
726
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
727
728
729
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
730
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
731
732
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
733
                                            static_cast<int>(num_col),
734
735
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
736
  } else {
737
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
738
    ret->CreateValid(
739
      reinterpret_cast<const Dataset*>(reference));
740
  }
741
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
742
  #pragma omp parallel for schedule(static)
743
  for (int i = 0; i < nindptr - 1; ++i) {
744
    OMP_LOOP_EX_BEGIN();
745
746
747
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
748
    OMP_LOOP_EX_END();
749
  }
750
  OMP_THROW_EX();
751
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
752
  *out = ret.release();
753
  API_END();
754
755
}

756
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
757
758
759
760
761
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
762
  API_BEGIN();
763
764
765
766
767
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
799
800
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
801
                                            static_cast<int>(num_col),
802
803
804
805
806
807
808
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
809

810
811
812
813
814
815
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
816
817
818
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
819
820
821
822
823
824
825
826
827
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
828
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
829
830
831
832
833
834
835
836
837
838
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
839
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
840
841
  auto param = Config::Str2Map(parameters);
  Config config;
842
843
844
845
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
846
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
847
848
849
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
850
851
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
852
    auto sample_indices = rand.Sample(nrow, sample_cnt);
853
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
854
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
855
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
856
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
857
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
858
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
859
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
860
861
862
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
863
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
864
865
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
866
867
        }
      }
868
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
869
    }
870
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
871
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
872
873
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
874
875
876
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
877
  } else {
878
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
879
    ret->CreateValid(
880
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
881
  }
882
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
883
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
884
  for (int i = 0; i < ncol_ptr - 1; ++i) {
885
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
886
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
887
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
888
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
889
890
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
891
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
Guolin Ke's avatar
Guolin Ke committed
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
    auto bin_mapper = ret->FeatureBinMapper(feature_idx);
    if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
      int row_idx = 0;
      while (row_idx < nrow) {
        auto pair = col_it.NextNonZero();
        row_idx = pair.first;
        // no more data
        if (row_idx < 0) { break; }
        ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
      }
    } else {
      for (int row_idx = 0; row_idx < nrow; ++row_idx) {
        auto val = col_it.Get(row_idx);
        ret->PushOneData(tid, row_idx, group, sub_feature, val);
      }
Guolin Ke's avatar
Guolin Ke committed
907
    }
908
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
909
  }
910
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
911
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
912
  *out = ret.release();
913
  API_END();
Guolin Ke's avatar
Guolin Ke committed
914
915
}

Guolin Ke's avatar
Guolin Ke committed
916
int LGBM_DatasetGetSubset(
917
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
918
919
920
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
921
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
922
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
923
924
  auto param = Config::Str2Map(parameters);
  Config config;
925
926
927
928
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
929
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
930
  CHECK(num_used_row_indices > 0);
931
932
933
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
934
935
936
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
937
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
938
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
939
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
940
941
942
943
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
944
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
945
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
946
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
947
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
948
949
950
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
951
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
952
953
954
955
956
957
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
958
int LGBM_DatasetGetFeatureNames(
959
960
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
961
  int* num_feature_names) {
962
963
964
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
965
966
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
967
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
968
969
970
971
  }
  API_END();
}

972
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
973
int LGBM_DatasetFree(DatasetHandle handle) {
974
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
975
  delete reinterpret_cast<Dataset*>(handle);
976
  API_END();
977
978
}

Guolin Ke's avatar
Guolin Ke committed
979
int LGBM_DatasetSaveBinary(DatasetHandle handle,
980
                           const char* filename) {
981
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
982
983
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
984
  API_END();
985
986
}

987
988
989
990
991
992
993
994
// Write a text dump of the dataset to `filename` via Dataset::DumpTextFile.
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
995
// Set a named metadata field of the dataset from a typed buffer.
// `type` selects the element type: C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or
// C_API_DTYPE_FLOAT64. Throws if the type is not one of these or the
// dataset's setter reports failure.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto* dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool is_success = false;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
      break;
    case C_API_DTYPE_INT32:
      is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
      break;
    case C_API_DTYPE_FLOAT64:
      is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
      break;
    default:
      // Unsupported element type: leave is_success false.
      break;
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1014
// Look up a named metadata field and return a pointer to its data.
// The getters are tried in order float, int32, double, int8, and the first
// one that succeeds determines *out_type. Throws if none matches. A null data
// pointer is reported with *out_len = 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (dataset->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1040
1041
1042
1043
1044
1045
1046
// Forward a parameter string to Dataset::ResetConfig.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1047
int LGBM_DatasetGetNumData(DatasetHandle handle,
1048
                           int* out) {
1049
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1050
1051
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1052
  API_END();
1053
1054
}

Guolin Ke's avatar
Guolin Ke committed
1055
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1056
                              int* out) {
1057
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1058
1059
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1060
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1061
}
1062

1063
1064
1065
1066
1067
1068
1069
1070
1071
// Append the features of `source` to `target` (Dataset::addFeaturesFrom).
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* destination = reinterpret_cast<Dataset*>(target);
  destination->addFeaturesFrom(reinterpret_cast<Dataset*>(source));
  API_END();
}

1072
1073
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1074
int LGBM_BoosterCreate(const DatasetHandle train_data,
1075
1076
                       const char* parameters,
                       BoosterHandle* out) {
1077
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1078
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1079
1080
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1081
  API_END();
1082
1083
}

Guolin Ke's avatar
Guolin Ke committed
1084
int LGBM_BoosterCreateFromModelfile(
1085
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1086
  int* out_num_iterations,
1087
  BoosterHandle* out) {
1088
  API_BEGIN();
wxchan's avatar
wxchan committed
1089
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1090
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1091
  *out = ret.release();
1092
  API_END();
1093
1094
}

Guolin Ke's avatar
Guolin Ke committed
1095
int LGBM_BoosterLoadModelFromString(
1096
1097
1098
1099
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1100
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1101
1102
1103
1104
1105
1106
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1107
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1108
int LGBM_BoosterFree(BoosterHandle handle) {
1109
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1110
  delete reinterpret_cast<Booster*>(handle);
1111
  API_END();
1112
1113
}

1114
// Forward the [start_iter, end_iter] arguments to Booster::ShuffleModels.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1121
// Merge `other_handle` into `handle` via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* destination = reinterpret_cast<Booster*>(handle);
  destination->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1130
int LGBM_BoosterAddValidData(BoosterHandle handle,
1131
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1132
1133
1134
1135
1136
1137
1138
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1139
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1140
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1141
1142
1143
1144
1145
1146
1147
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1148
// Forward a parameter string to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1155
// Store the underlying boosting model's NumberOfClasses() in *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1162
1163
1164
1165
1166
1167
1168
// Forward an nrow x ncol leaf-prediction buffer to Booster::Refit.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1169
// Run one training iteration. *is_finished is set to 1 when
// Booster::TrainOneIter returns true, otherwise 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1180
// Run one training iteration using caller-supplied gradients and hessians.
// *is_finished is set to 1 when Booster::TrainOneIter returns true, otherwise 0.
// Not available when built with SCORE_T_USE_DOUBLE; that build reports an error.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1198
// Undo the most recent training iteration via Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1205
// Store the boosting model's GetCurrentIteration() in *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1211

1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
// Store the boosting model's NumModelPerIteration() in *out_tree_per_iteration.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

// Store the boosting model's NumberOfTotalModel() in *out_models.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1226
// Store Booster::GetEvalCounts() in *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1233
// Forward the caller's string buffers to Booster::GetEvalNames and store its
// return value in *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1240
// Forward the caller's string buffers to Booster::GetFeatureNames and store
// its return value in *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1247
// Writes the number of features the model expects (largest feature index + 1).
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto* boosting = static_cast<Booster*>(handle)->GetBoosting();
  const int max_feature_index = boosting->MaxFeatureIdx();
  *out_len = max_feature_index + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1254
// Copies every evaluation result for the dataset at data_idx into out_results
// and writes the number of results to *out_len.
// NOTE(review): out_results is written without a bounds check; the caller presumably
// sizes it via LGBM_BoosterGetEvalCounts — confirm.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto* booster = static_cast<Booster*>(handle);
  auto results = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(results.size());
  double* dst = out_results;
  for (const auto& value : results) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1269
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1270
1271
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1272
1273
1274
1275
1276
1277
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1278
int LGBM_BoosterGetPredict(BoosterHandle handle,
1279
1280
1281
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1282
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1283
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1284
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1285
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1286
1287
}

Guolin Ke's avatar
Guolin Ke committed
1288
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1289
1290
1291
1292
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1293
                               const char* parameter,
1294
                               const char* result_filename) {
1295
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1296
1297
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1298
1299
1300
1301
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1302
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1303
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1304
                       config, result_filename);
1305
  API_END();
1306
1307
}

Guolin Ke's avatar
Guolin Ke committed
1308
// Computes how many output values a prediction over num_row rows will produce:
// num_row times the per-row output size for the requested predict_type.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  const bool want_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool want_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  auto* booster = static_cast<Booster*>(handle);
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(num_iteration, want_leaf_index, want_contrib);
  // Widen before multiplying so large row counts do not overflow int.
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1320
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1321
1322
1323
1324
1325
1326
1327
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1328
                              int64_t num_col,
1329
1330
                              int predict_type,
                              int num_iteration,
1331
                              const char* parameter,
1332
1333
                              int64_t* out_len,
                              double* out_result) {
1334
  API_BEGIN();
1335
1336
1337
1338
1339
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1340
1341
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1342
1343
1344
1345
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1346
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1347
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1348
  int nrow = static_cast<int>(nindptr - 1);
1349
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1350
                       config, out_result, out_len);
1351
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1352
}
1353

1354
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1355
1356
1357
1358
1359
1360
1361
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1362
                                       int64_t num_col,
1363
1364
1365
1366
1367
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1368
  API_BEGIN();
1369
1370
1371
1372
1373
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1374
1375
1376
1377
1378
1379
1380
1381
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1382
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1383
1384
1385
1386
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1387
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1398
                              const char* parameter,
1399
1400
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1401
1402
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1403
1404
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1415
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1416
1417
1418
1419
1420
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1421
1422
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1423
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1424
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1425
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1426
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1427
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1428
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1429
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1430
1431
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1432
1433
    return one_row;
  };
1434
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1435
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1436
1437
1438
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1439
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1440
1441
1442
1443
1444
1445
1446
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1447
                              const char* parameter,
1448
1449
                              int64_t* out_len,
                              double* out_result) {
1450
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1451
1452
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1453
1454
1455
1456
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1457
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1458
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1459
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1460
                       config, out_result, out_len);
1461
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1462
}
1463

1464
// Predicts for a single dense row of ncol values via Booster::PredictSingleRow.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Config config;
  auto param_map = Config::Str2Map(parameter);
  config.Set(param_map);
  const int requested_threads = config.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }
  // Treat the input as a 1 x ncol matrix.
  const int kSingleRow = 1;
  auto row_reader = RowPairFunctionFromDenseMatric(data, kSingleRow, ncol, data_type, is_row_major);
  auto* booster = static_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_reader, config, out_result, out_len);
  API_END();
}


1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
// Predicts for nrow dense rows given as an array of per-row pointers
// (data[i] points at ncol values), writing results to out_result / out_len.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Config config;
  auto param_map = Config::Str2Map(parameter);
  config.Set(param_map);
  const int requested_threads = config.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  auto* booster = static_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_reader, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1511
int LGBM_BoosterSaveModel(BoosterHandle handle,
1512
                          int start_iteration,
1513
1514
                          int num_iteration,
                          const char* filename) {
1515
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1516
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1517
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1518
1519
1520
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1521
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1522
                                  int start_iteration,
1523
                                  int num_iteration,
1524
                                  int64_t buffer_len,
1525
                                  int64_t* out_len,
1526
                                  char* out_str) {
1527
1528
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1529
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1530
  *out_len = static_cast<int64_t>(model.size()) + 1;
1531
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1532
    std::memcpy(out_str, model.c_str(), *out_len);
1533
1534
1535
1536
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1537
int LGBM_BoosterDumpModel(BoosterHandle handle,
1538
                          int start_iteration,
1539
                          int num_iteration,
1540
1541
                          int64_t buffer_len,
                          int64_t* out_len,
1542
                          char* out_str) {
wxchan's avatar
wxchan committed
1543
1544
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1545
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1546
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1547
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1548
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1549
  }
1550
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1551
}
1552

Guolin Ke's avatar
Guolin Ke committed
1553
// Reads the output value of leaf leaf_idx in tree tree_idx into *out_val.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  const auto leaf_value = static_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1563
// Overwrites the output value of leaf leaf_idx in tree tree_idx with val.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  static_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
// Computes per-feature importance and copies it into out_results.
// NOTE(review): out_results is written without a bounds check. It presumably needs
// one slot per model feature — confirm against callers.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  auto* booster = static_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double value : importances) {
    *dst++ = value;
  }
  API_END();
}

1586
1587
1588
1589
1590
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1591
  Config config;
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Releases networking resources by calling Network::Dispose().
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();  // any exception is converted to -1 by API_END
  API_END();
}

1608
1609
1610
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1611
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1612
  if (num_machines > 1) {
1613
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1614
1615
1616
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1617

Guolin Ke's avatar
Guolin Ke committed
1618
// ---- start of helper functions
1619
1620
1621

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1622
  if (data_type == C_API_DTYPE_FLOAT32) {
1623
1624
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1625
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1626
        std::vector<double> ret(num_col);
1627
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1628
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1629
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1630
1631
1632
1633
        }
        return ret;
      };
    } else {
1634
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1635
        std::vector<double> ret(num_col);
1636
        for (int i = 0; i < num_col; ++i) {
1637
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1638
1639
1640
1641
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1642
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1643
1644
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1645
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1646
        std::vector<double> ret(num_col);
1647
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1648
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1649
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1650
1651
1652
1653
        }
        return ret;
      };
    } else {
1654
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1655
        std::vector<double> ret(num_col);
1656
        for (int i = 0; i < num_col; ++i) {
1657
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1658
1659
1660
1661
1662
        }
        return ret;
      };
    }
  }
1663
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1664
1665
1666
1667
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1668
1669
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1670
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1671
1672
1673
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1674
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1675
          ret.emplace_back(i, raw_values[i]);
1676
        }
Guolin Ke's avatar
Guolin Ke committed
1677
1678
1679
      }
      return ret;
    };
1680
  }
Guolin Ke's avatar
Guolin Ke committed
1681
  return nullptr;
1682
1683
}

1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

1700
std::function<std::vector<std::pair<int, double>>(int idx)>
1701
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1702
  if (data_type == C_API_DTYPE_FLOAT32) {
1703
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1704
    if (indptr_type == C_API_DTYPE_INT32) {
1705
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1706
      return [=] (int idx) {
1707
1708
1709
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1710
1711
1712
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1713
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1714
          ret.emplace_back(indices[i], data_ptr[i]);
1715
1716
1717
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1718
    } else if (indptr_type == C_API_DTYPE_INT64) {
1719
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1720
      return [=] (int idx) {
1721
1722
1723
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1724
1725
1726
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1727
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1728
          ret.emplace_back(indices[i], data_ptr[i]);
1729
1730
1731
1732
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1733
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1734
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1735
    if (indptr_type == C_API_DTYPE_INT32) {
1736
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1737
      return [=] (int idx) {
1738
1739
1740
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1741
1742
1743
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1744
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1745
          ret.emplace_back(indices[i], data_ptr[i]);
1746
1747
1748
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1749
    } else if (indptr_type == C_API_DTYPE_INT64) {
1750
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1751
      return [=] (int idx) {
1752
1753
1754
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1755
1756
1757
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1758
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1759
          ret.emplace_back(indices[i], data_ptr[i]);
1760
1761
1762
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1763
1764
    }
  }
1765
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1766
1767
}

Guolin Ke's avatar
Guolin Ke committed
1768
std::function<std::pair<int, double>(int idx)>
1769
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1770
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1771
  if (data_type == C_API_DTYPE_FLOAT32) {
1772
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1773
    if (col_ptr_type == C_API_DTYPE_INT32) {
1774
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1775
1776
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1777
1778
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1779
1780
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1781
        }
Guolin Ke's avatar
Guolin Ke committed
1782
1783
1784
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1785
      };
Guolin Ke's avatar
Guolin Ke committed
1786
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1787
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1788
1789
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1790
1791
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1792
1793
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1794
        }
Guolin Ke's avatar
Guolin Ke committed
1795
1796
1797
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1798
      };
Guolin Ke's avatar
Guolin Ke committed
1799
    }
Guolin Ke's avatar
Guolin Ke committed
1800
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1801
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1802
    if (col_ptr_type == C_API_DTYPE_INT32) {
1803
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1804
1805
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1806
1807
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1808
1809
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1810
        }
Guolin Ke's avatar
Guolin Ke committed
1811
1812
1813
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1814
      };
Guolin Ke's avatar
Guolin Ke committed
1815
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1816
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1817
1818
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1819
1820
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1821
1822
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1823
        }
Guolin Ke's avatar
Guolin Ke committed
1824
1825
1826
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1827
      };
Guolin Ke's avatar
Guolin Ke committed
1828
1829
    }
  }
1830
  throw std::runtime_error("Unknown data type in CSC matrix");
1831
1832
}

Guolin Ke's avatar
Guolin Ke committed
1833
// Binds this iterator to column col_idx of the given CSC matrix by storing the
// column cursor produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                       ncol_ptr, nelem, col_idx)) {
}

// Returns this column's value at row idx, or 0 if no entry is stored there.
// The cursor only moves forward. Requesting a row earlier than one already passed
// returns 0, so callers should ask for rows in non-decreasing order.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (idx == cur_idx_) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) entry of this column. Once the column
// is exhausted it returns (-1, 0.0) and marks the iterator as ended.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // is_end_ is false here, so this only ever flips it to true.
  is_end_ = (entry.first < 0);
  return entry;
}