c_api.cpp 37.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

Guolin Ke's avatar
Guolin Ke committed
34
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
35
36
37
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
38
39
40
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
41
42
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
43
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
44
45
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
46

wxchan's avatar
wxchan committed
47
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
48

Guolin Ke's avatar
Guolin Ke committed
49
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
50
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
51
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
52
53

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
54
55
56
57
58
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
59
60
61
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
62

Guolin Ke's avatar
Guolin Ke committed
63
  }
64

wxchan's avatar
wxchan committed
65
66
67
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
90
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
91
92
93
94
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
95
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
96
97
98
99
100
101
102
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
106
107

    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
115
116
117
118
119
120
121
122

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
123
    }
Guolin Ke's avatar
Guolin Ke committed
124
125
126

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
127

wxchan's avatar
wxchan committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
143

144
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
145
    std::lock_guard<std::mutex> lock(mutex_);
146
147
148
149
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
150
    std::lock_guard<std::mutex> lock(mutex_);
151
152
153
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
154
155
156
157
158
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
159
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
160
161
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
162
163
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
164
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
165
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
166
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
167
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
168
169
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
170
    }
Guolin Ke's avatar
Guolin Ke committed
171
172
173
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
174
175
  }

Guolin Ke's avatar
Guolin Ke committed
176
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
177
178
179
180
181
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
182
  }
183

wxchan's avatar
wxchan committed
184
185
186
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
187

Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
193
194
195
196
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
197
198
199
200
201
202
203
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
204

wxchan's avatar
wxchan committed
205
206
207
208
209
210
211
212
213
214
215
216
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
217

Guolin Ke's avatar
Guolin Ke committed
218
private:
219

wxchan's avatar
wxchan committed
220
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
221
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
222
223
224
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
225
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
226
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
227
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
228
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
229
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
230
231
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
232
233
234
};

}
Guolin Ke's avatar
Guolin Ke committed
235
236
237

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
// Helper functions used to convert caller-supplied matrices into per-row
// accessors. They are only declared here; the definitions are elsewhere in
// this file.

// Returns a callable mapping a row index to that row as a vector of doubles.
// NOTE(review): data_type / is_row_major presumably select the element type
// and memory layout of `data` - confirm against the definition.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same inputs as above, but the callable yields (int, double) pairs
// (presumably column index and value - confirm against the definition).
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Returns a callable mapping an index to a vector of (int, double) pairs read
// from CSR-style arrays (indptr / indices / data).
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Row iterator over one column of a CSC matrix.
// Only the interface is declared here; the constructor, Get and NextNonZero
// are defined elsewhere in this file.
class CSC_RowIterator {
public:
  // col_idx selects the column to iterate; the remaining arguments describe
  // the CSC arrays (column pointers, row indices, values) and their types.
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Returns the value at idx; indices may only be queried in ascending order.
  double Get(int idx);
  // Returns the next non-zero (index, value) pair; an index < 0 means there is no more data.
  std::pair<int, double> NextNonZero();
private:
  // Iteration state and the element accessor; their exact roles are set by
  // the out-of-line definitions.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0f;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
270
/*! \brief Return the string currently provided by LastErrorMsg(). */
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

wxchan's avatar
wxchan committed
274
/*!
 * \brief Load a dataset from a file.
 * \param filename Path of the data file
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional existing dataset to align with; may be null
 * \param out Receives the newly created dataset handle
 */
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  const Dataset* aligned_with = reinterpret_cast<const Dataset*>(reference);
  if (aligned_with != nullptr) {
    // reuse the layout of the reference dataset
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
292
/*!
 * \brief Create a dataset from a dense matrix.
 * \param data Matrix storage
 * \param data_type Element type code forwarded to RowFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are copied; may be null
 * \param out Receives the newly created dataset handle
 */
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // take the feature mappers from the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // no reference: sample rows first and construct the dataset from the sample
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto sampled_row : sample_indices) {
      auto row = row_at(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row.size(); ++col) {
        // values this close to zero are not recorded in the sample
        if (std::fabs(row[col]) > 1e-15) {
          sample_values[col].push_back(row[col]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row into the dataset, one OpenMP thread id per worker
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = row_at(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
341
/*!
 * \brief Create a dataset from a CSR matrix.
 * \param indptr Row pointer array with nindptr entries
 * \param indptr_type Type code of indptr
 * \param indices Column indices
 * \param data Values
 * \param data_type Type code of data
 * \param nindptr Length of indptr; the row count is nindptr - 1
 * \param nelem Number of stored elements
 * \param num_col Number of columns; checked against the widest sampled row
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are copied; may be null
 * \param out Receives the newly created dataset handle
 */
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // take the feature mappers from the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // no reference: sample rows first and construct the dataset from the sample
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto sampled_row : sample_indices) {
      auto row = row_at(static_cast<int>(sampled_row));
      for (std::pair<int, double>& entry : row) {
        // grow the per-feature sample list until this column index is addressable
        while (sample_values.size() <= static_cast<size_t>(entry.first)) {
          sample_values.emplace_back();
        }
        // values this close to zero are not recorded in the sample
        if (std::fabs(entry.second) > 1e-15) {
          sample_values[entry.first].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row into the dataset, one OpenMP thread id per worker
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = row_at(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
403
/*!
 * \brief Create a dataset from a CSC matrix.
 * \param col_ptr Column pointer array with ncol_ptr entries
 * \param col_ptr_type Type code of col_ptr
 * \param indices Row indices
 * \param data Values
 * \param data_type Type code of data
 * \param ncol_ptr Length of col_ptr; the column count is ncol_ptr - 1
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are copied; may be null
 * \param out Receives the newly created dataset handle
 */
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    // each column is sampled independently, so columns can be processed in parallel
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // a negative inner index means this column is not used by the dataset
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    while (true) {
      auto pair = col_it.NextNonZero();
      const int row_idx = pair.first;
      // a negative index means no more data; an index at or beyond nrow is
      // outside the dataset and must not be pushed
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
465
/*!
 * \brief Create a new dataset from selected rows of an existing one.
 * \param handle Source dataset
 * \param used_row_indices Row indices to keep
 * \param num_used_row_indices Number of entries in used_row_indices
 * \param parameters Parameter string used to fill an IOConfig
 * \param out Receives the newly created dataset handle
 */
DllExport int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
485
/*!
 * \brief Assign feature names to a dataset.
 * \param handle Dataset to modify
 * \param feature_names Array of C strings
 * \param num_feature_names Number of entries in feature_names
 */
DllExport int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  // copy each C string into an owned std::string
  int64_t pos = 0;
  while (pos < num_feature_names) {
    names.emplace_back(feature_names[pos]);
    ++pos;
  }
  reinterpret_cast<Dataset*>(handle)->set_feature_names(names);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
499
/*! \brief Destroy the dataset referred to by handle. */
DllExport int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
505
/*!
 * \brief Save a dataset through Dataset::SaveBinaryFile.
 * \param handle Dataset to save
 * \param filename Destination path
 */
DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
513
/*!
 * \brief Set a named field on a dataset.
 * \param handle Dataset to modify
 * \param field_name Name of the field
 * \param field_data Pointer to the field values
 * \param num_element Number of values; narrowed to int32_t before use
 * \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value is an error
 * Throws std::runtime_error (inside the API_BEGIN/API_END guard) when the
 * type is unsupported or the dataset rejects the field.
 */
DllExport int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
530
/*!
 * \brief Look up a named field of a dataset.
 * \param handle Dataset to query
 * \param field_name Name of the field
 * \param out_len Receives the number of elements; forced to 0 when the data pointer is null
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
 * Throws std::runtime_error (inside the API_BEGIN/API_END guard) when neither
 * a float nor an int field matches.
 */
DllExport int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  // float fields are tried first, int fields second
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
550
/*! \brief Write the value of Dataset::num_data() for handle into out. */
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
558
/*! \brief Write the value of Dataset::num_total_features() for handle into out. */
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
565
566
567

// ---- start of booster

Guolin Ke's avatar
typo  
Guolin Ke committed
568
/*!
 * \brief Create a booster for training.
 * \param train_data Training dataset handle
 * \param parameters Parameter string passed to the Booster constructor
 * \param out Receives the newly created booster handle
 */
DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // hold the booster in a unique_ptr until ownership is handed to the caller
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
578
/*!
 * \brief Create a booster from a model file.
 * \param filename Path of the model file
 * \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses()
 * \param out Receives the newly created booster handle
 */
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(model->NumberOfTotalModel()
    / model->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*! \brief Destroy the booster referred to by handle. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
596
597
598
599
600
601
602
603
604
605
/*!
 * \brief Merge the model of other_handle into handle.
 * \param handle Booster that receives the merge
 * \param other_handle Booster that is merged from
 */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

/*!
 * \brief Register a validation dataset with a booster.
 * \param handle Booster to modify
 * \param valid_data Validation dataset handle
 */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
 * \brief Replace the training data of a booster.
 * \param handle Booster to modify
 * \param train_data New training dataset handle
 */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

/*!
 * \brief Apply a new parameter string to a booster via Booster::ResetConfig.
 * \param handle Booster to modify
 * \param parameters Parameter string
 */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

/*! \brief Write the model's NumberOfClasses() into out_len. */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

637
/*!
 * \brief Run one training iteration.
 * \param handle Booster to train
 * \param is_finished Receives 1 when Booster::TrainOneIter returns true, otherwise 0
 */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
 * \brief Run one training iteration with caller-supplied gradients and hessians.
 * \param handle Booster to train
 * \param grad Gradient values
 * \param hess Hessian values
 * \param is_finished Receives 1 when Booster::TrainOneIter returns true, otherwise 0
 */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
662
663
664
665
666
667
668
669
670
671
672
673
674
/*! \brief Undo the latest iteration of the booster via Booster::RollbackOneIter. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Write the model's GetCurrentIteration() into out_iteration. */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
675

wxchan's avatar
wxchan committed
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
/*! \brief Write the result of Booster::GetEvalCounts() into out_len. */
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

/*!
 * \brief Copy the evaluation names of a booster into caller-provided buffers.
 * \param handle Booster to query
 * \param out_len Receives the number of names written
 * \param out_strs Array of writable buffers, one per name
 */
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

/*!
 * \brief Fetch evaluation results for one dataset.
 * \param handle Booster to query
 * \param data_idx Dataset index forwarded to Boosting::GetEvalAt
 * \param out_len Receives the number of results
 * \param out_results Caller-provided buffer that receives the results as doubles
 */
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = model->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(scores.size());
  // copy results into the caller's buffer, converting each to double
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst = static_cast<double>(score);
    ++dst;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
705
706
707
708
709
710
711
712
713
/*!
 * \brief Write the model's GetNumPredictAt(data_idx) into out_len.
 * \param handle Booster to query
 * \param data_idx Dataset index
 * \param out_len Receives the count
 */
DllExport int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
714
/*!
 * \brief Fetch predictions for one dataset via Booster::GetPredictAt.
 * \param handle Booster to query
 * \param data_idx Dataset index
 * \param out_len Receives the number of values written
 * \param out_result Caller-provided buffer that receives the predictions
 */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

724
725
/*!
 * \brief Predict for the data in a file and write the results to another file.
 * \param handle Booster used for prediction
 * \param data_filename Input data file
 * \param data_has_header Treated as true when greater than 0
 * \param predict_type One of the C_API_PREDICT_* constants
 * \param num_iteration Narrowed to int and passed to Booster::NewPredictor
 * \param result_filename Output file
 */
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool has_header = data_has_header > 0;
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
// Number of output values produced for one input row.
// For every predict type except C_API_PREDICT_LEAF_INDEX this is the number of
// classes. For leaf-index prediction it is the number of classes multiplied by
// the number of iterations used: min(current iteration, num_iteration) when
// num_iteration > 0, otherwise the current iteration.
int GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int values_per_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type != C_API_PREDICT_LEAF_INDEX) {
    return values_per_row;
  }
  const int64_t trained_iterations = ref_booster->GetBoosting()->GetCurrentIteration();
  if (num_iteration <= 0) {
    // no explicit limit: use every trained iteration
    values_per_row = static_cast<int>(values_per_row * trained_iterations);
    return values_per_row;
  }
  values_per_row *= static_cast<int>(std::min(trained_iterations, num_iteration));
  return values_per_row;
}

// Computes the total number of prediction values for `num_row` rows:
// num_row * GetNumPredOneRow(...), written to `out_len`.
DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int64_t num_row,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const int values_per_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  const int64_t total_values = num_row * values_per_row;
  *out_len = total_values;
  API_END();
}

762
// Predicts every row of a CSR matrix and writes the results to `out_result`.
//
// Row i occupies out_result[i * P .. i * P + P - 1], where P is the value
// returned by GetNumPredOneRow(). `*out_len` receives nrow * P with
// nrow = nindptr - 1. The unnamed int64_t parameter is ignored.
//
// All offsets into `out_result` and the value stored in `*out_len` are computed
// in 64-bit arithmetic; a 32-bit `int * int` product overflows once
// nrow * P exceeds INT_MAX.
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // widened once so every product below is evaluated as int64_t
  const int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = predictor.GetPredictFunction()(one_row);
    double* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<double>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
792

Guolin Ke's avatar
Guolin Ke committed
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
// Predicts every row of a CSC matrix and writes the results to `out_result`.
//
// The row range [0, num_row) is split by Threading::For. Each chunk builds one
// CSC_RowIterator per column, then assembles each row from the per-column
// values whose magnitude exceeds kEpsilon before calling the predictor.
// Row i occupies out_result[i * P ..], where P = GetNumPredOneRow(); `*out_len`
// receives num_row * P.
//
// The chunk callback takes its range as int64_t so that it matches
// Threading::For<int64_t>; declaring the bounds with a narrower index type
// would truncate them before the loop runs.
DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  Threading::For<int64_t>(0, num_row,
    [&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    // one iterator per column, private to this chunk
    std::vector<CSC_RowIterator> iterators;
    iterators.reserve(ncol > 0 ? static_cast<size_t>(ncol) : 0);
    for (int j = 0; j < ncol; ++j) {
      iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
    std::vector<std::pair<int, double>> one_row;
    for (int64_t i = start; i < end; ++i) {
      one_row.clear();
      for (int j = 0; j < ncol; ++j) {
        auto val = iterators[j].Get(static_cast<int>(i));
        if (std::fabs(val) > kEpsilon) {
          one_row.emplace_back(j, val);
        }
      }
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      // i is int64_t, so this offset is computed in 64-bit arithmetic
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    }
  });
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

838
839
// Predicts every row of a dense matrix and writes the results to `out_result`.
//
// Row i occupies out_result[i * P .. i * P + P - 1], where P is the value
// returned by GetNumPredOneRow(). `*out_len` receives nrow * P.
//
// All offsets into `out_result` and the value stored in `*out_len` are computed
// in 64-bit arithmetic; a 32-bit `int * int` product overflows once
// nrow * P exceeds INT_MAX.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  // widened once so every product below is evaluated as int64_t
  const int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = predictor.GetPredictFunction()(one_row);
    double* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<double>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
864
865

// Forwards to Booster::SaveModelToFile(num_iteration, filename).
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Dumps the model to a string and copies it into the caller's buffer.
// `*out_len` always receives the required size (string length plus the
// terminating NUL). The string is copied into `out_str` only when that size
// fits in `buffer_len`; otherwise `out_str` is left untouched.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char* out_str) {
  API_BEGIN();
  const std::string dumped = reinterpret_cast<Booster*>(handle)->DumpModel();
  const int64_t required_len = static_cast<int64_t>(dumped.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
887

Guolin Ke's avatar
Guolin Ke committed
888
889
890
// Reads the value returned by Booster::GetLeafValue(tree_idx, leaf_idx) and
// stores it, converted to double, in `out_val`.
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  const auto leaf_value = reinterpret_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
908
// ---- start of helper functions
909
910
911

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
912
  if (data_type == C_API_DTYPE_FLOAT32) {
913
914
915
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
916
        std::vector<double> ret(num_col);
917
918
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
919
          ret[i] = static_cast<double>(*(tmp_ptr + i));
920
921
922
923
924
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
925
        std::vector<double> ret(num_col);
926
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
927
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
928
929
930
931
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
932
  } else if (data_type == C_API_DTYPE_FLOAT64) {
933
934
935
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
936
        std::vector<double> ret(num_col);
937
938
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
939
          ret[i] = static_cast<double>(*(tmp_ptr + i));
940
941
942
943
944
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
945
        std::vector<double> ret(num_col);
946
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
947
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
948
949
950
951
952
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
953
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
954
955
956
957
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
958
959
960
961
962
963
964
965
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
966
        }
Guolin Ke's avatar
Guolin Ke committed
967
968
969
      }
      return ret;
    };
970
  }
Guolin Ke's avatar
Guolin Ke committed
971
  return nullptr;
972
973
974
975
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
976
  if (data_type == C_API_DTYPE_FLOAT32) {
977
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
978
    if (indptr_type == C_API_DTYPE_INT32) {
979
980
981
982
983
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
984
        for (int64_t i = start; i < end; ++i) {
985
986
987
988
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
989
    } else if (indptr_type == C_API_DTYPE_INT64) {
990
991
992
993
994
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
995
        for (int64_t i = start; i < end; ++i) {
996
997
998
999
1000
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1001
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1002
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1003
    if (indptr_type == C_API_DTYPE_INT32) {
1004
1005
1006
1007
1008
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1009
        for (int64_t i = start; i < end; ++i) {
1010
1011
1012
1013
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1014
    } else if (indptr_type == C_API_DTYPE_INT64) {
1015
1016
1017
1018
1019
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1020
        for (int64_t i = start; i < end; ++i) {
1021
1022
1023
1024
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1025
1026
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1027
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1028
1029
}

Guolin Ke's avatar
Guolin Ke committed
1030
1031
1032
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1033
  if (data_type == C_API_DTYPE_FLOAT32) {
1034
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1035
    if (col_ptr_type == C_API_DTYPE_INT32) {
1036
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1037
1038
1039
1040
1041
1042
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1043
        }
Guolin Ke's avatar
Guolin Ke committed
1044
1045
1046
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1047
      };
Guolin Ke's avatar
Guolin Ke committed
1048
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1049
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1050
1051
1052
1053
1054
1055
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1056
        }
Guolin Ke's avatar
Guolin Ke committed
1057
1058
1059
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1060
      };
Guolin Ke's avatar
Guolin Ke committed
1061
    }
Guolin Ke's avatar
Guolin Ke committed
1062
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1063
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1064
    if (col_ptr_type == C_API_DTYPE_INT32) {
1065
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1066
1067
1068
1069
1070
1071
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1072
        }
Guolin Ke's avatar
Guolin Ke committed
1073
1074
1075
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1076
      };
Guolin Ke's avatar
Guolin Ke committed
1077
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1078
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1079
1080
1081
1082
1083
1084
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1085
        }
Guolin Ke's avatar
Guolin Ke committed
1086
1087
1088
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1089
      };
Guolin Ke's avatar
Guolin Ke committed
1090
1091
1092
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1093
1094
}

Guolin Ke's avatar
Guolin Ke committed
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
// Binds this iterator to column `col_idx` of the given CSC matrix by obtaining
// the column's entry reader from IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when the column has no stored
// entry for that row.
// Advances through the column's stored entries until the current entry's row
// index reaches `idx` or the column is exhausted; the position is kept between
// calls.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // a negative row index marks the end of the column
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0f;
  }
  return cur_val_;
}

// Returns the next stored (row index, value) entry of the column and advances
// the position. Once the column is exhausted, returns (-1, 0.0) on this and
// every subsequent call.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  std::pair<int, double> entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // is_end_ is known to be false here, so a plain assignment is enough
  is_end_ = (entry.first < 0);
  return entry;
}