c_api.cpp 37.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

Guolin Ke's avatar
Guolin Ke committed
34
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
35
36
37
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
38
39
40
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
41
42
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
43
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
44
45
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
46

wxchan's avatar
wxchan committed
47
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
48

Guolin Ke's avatar
Guolin Ke committed
49
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
50
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
51
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
52
53

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
54
55
56
57
58
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
59
60
61
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
62

Guolin Ke's avatar
Guolin Ke committed
63
  }
64

wxchan's avatar
wxchan committed
65
66
67
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
90
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
91
92
93
94
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
95
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
96
97
98
99
100
101
102
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
106
107

    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
115
116
117
118
119
120
121
122

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
123
    }
Guolin Ke's avatar
Guolin Ke committed
124
125
126

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
127

wxchan's avatar
wxchan committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
143

144
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
145
    std::lock_guard<std::mutex> lock(mutex_);
146
147
148
149
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
150
    std::lock_guard<std::mutex> lock(mutex_);
151
152
153
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
154
155
156
157
158
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
159
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
160
161
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
162
163
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
164
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
165
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
166
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
167
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
168
169
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
170
    }
Guolin Ke's avatar
Guolin Ke committed
171
172
173
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
174
175
  }

Guolin Ke's avatar
Guolin Ke committed
176
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
177
178
179
180
181
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
182
  }
183

184
185
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
186
  }
187

Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
193
194
195
196
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
197
198
199
200
201
202
203
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
204

wxchan's avatar
wxchan committed
205
206
207
208
209
210
211
212
213
214
215
216
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
217

Guolin Ke's avatar
Guolin Ke committed
218
private:
219

wxchan's avatar
wxchan committed
220
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
221
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
222
223
224
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
225
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
226
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
227
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
228
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
229
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
230
231
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
232
233
234
};

}
Guolin Ke's avatar
Guolin Ke committed
235
236
237

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
// helper functions used to convert input data

// Builds a callable that returns one row of a dense matrix as a vector of
// values; callers in this file index the result by column.
// Defined elsewhere in this file.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Builds a callable that returns one row of a dense matrix as (int, double)
// pairs. NOTE(review): presumably (column index, value) - confirm against the
// definition, which is not visible here.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Builds a callable that returns one row of a CSR matrix as (int, double)
// pairs; callers in this file use .first as the feature index and .second as
// the value. Defined elsewhere in this file.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Row iterator over one column of a CSC matrix.
// The constructor, Get and NextNonZero are defined elsewhere in this file.
class CSC_RowIterator {
public:
  // col_idx selects the column to iterate; the remaining arguments describe
  // the CSC arrays handed in through the C API.
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Returns the value at row idx; rows may only be requested in ascending order.
  double Get(int idx);
  // Returns the next non-zero (row index, value) pair; a negative index means
  // there is no more data.
  std::pair<int, double> NextNonZero();
private:
  // NOTE(review): the meaning of the state below is inferred from names only;
  // confirm against the out-of-line definitions.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0f;
  bool is_end_ = false;
  // Maps a position to an (index, value) pair of the selected column.
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
270
/*! \brief Expose the string currently held by LastErrorMsg() to C callers. */
DllExport const char* LGBM_GetLastError() {
  const char* last_error = LastErrorMsg();
  return last_error;
}

wxchan's avatar
wxchan committed
274
/*!
 * \brief Load a dataset from a file.
 * \param filename File to load
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param reference Optional dataset to align with; nullptr loads standalone
 * \param out Receives the handle of the loaded dataset
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    // align the new dataset with the one supplied by the caller
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
      reinterpret_cast<const Dataset*>(reference));
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
292
/*!
 * \brief Build a dataset from a dense in-memory matrix.
 * \param data Matrix buffer, interpreted according to data_type
 * \param data_type Element type code forwarded to RowFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param reference Optional dataset whose feature mappers are copied;
 *        nullptr constructs the dataset from sampled rows instead
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // copy the feature mappers from the reference dataset
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto sampled_row : sample_indices) {
      const auto row_values = get_row_fun(static_cast<int>(sampled_row));
      size_t col = 0;
      for (const double value : row_values) {
        // keep only values that are not (numerically) zero
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
        ++col;
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row into the dataset, one OpenMP thread id per push
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
341
/*!
 * \brief Build a dataset from a CSR sparse matrix.
 * \param indptr Row pointer array, interpreted according to indptr_type
 * \param indptr_type Type code forwarded to RowFunctionFromCSR
 * \param indices Column indices of the stored elements
 * \param data Stored values, interpreted according to data_type
 * \param data_type Type code forwarded to RowFunctionFromCSR
 * \param nindptr Length of indptr; the row count is nindptr - 1
 * \param nelem Number of stored elements
 * \param num_col Number of columns; checked against the widest sampled row
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param reference Optional dataset whose feature mappers are copied;
 *        nullptr constructs the dataset from sampled rows instead
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // copy the feature mappers from the reference dataset
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (const auto sampled_row : sample_indices) {
      auto row = get_row_fun(static_cast<int>(sampled_row));
      for (std::pair<int, double>& entry : row) {
        // grow the per-feature buckets until this feature index is valid
        while (sample_values.size() <= static_cast<size_t>(entry.first)) {
          sample_values.emplace_back();
        }
        if (std::fabs(entry.second) > 1e-15) {
          // record the non-zero value under its feature
          sample_values[entry.first].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row into the dataset, one OpenMP thread id per push
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
403
/*!
 * \brief Build a dataset from a CSC sparse matrix.
 * \param col_ptr Column pointer array, interpreted according to col_ptr_type
 * \param col_ptr_type Type code forwarded to CSC_RowIterator
 * \param indices Row indices of the stored elements
 * \param data Stored values, interpreted according to data_type
 * \param data_type Type code forwarded to CSC_RowIterator
 * \param ncol_ptr Length of col_ptr; the column count is ncol_ptr - 1
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param reference Optional dataset whose feature mappers are copied;
 *        nullptr constructs the dataset from sampled rows instead
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // bound the inner loop by what was actually sampled, as the dense and CSR
    // constructors do, instead of trusting sample_cnt to match the vector size
    const int num_sampled = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      // each column gets its own iterator; Get is called with the sampled
      // row indices in the order returned by rand.Sample
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < num_sampled; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // a negative inner index means this column is not used by the dataset
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    while (true) {
      auto pair = col_it.NextNonZero();
      const int row_idx = pair.first;
      // negative index: no more data. An index >= nrow is outside the
      // dataset and must not be pushed; stop exactly where the original
      // loop condition would have stopped, but before writing.
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
465
/*!
 * \brief Create a new dataset from selected rows of an existing one.
 * \param handle Dataset to take rows from
 * \param used_row_indices Indices of the rows to keep
 * \param num_used_row_indices Length of used_row_indices
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param out Receives the handle of the subset dataset
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  // hold the subset in a unique_ptr until it is fully loaded
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
485
/*!
 * \brief Replace the feature names stored in a dataset.
 * \param handle Dataset to modify
 * \param feature_names Array of NUL-terminated names
 * \param num_feature_names Number of entries read from feature_names
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  // copy the C strings into owned std::string objects
  std::vector<std::string> names;
  int i = 0;
  while (i < num_feature_names) {
    names.push_back(std::string(feature_names[i]));
    ++i;
  }
  target->set_feature_names(names);
  API_END();
}

499
500
501
/*!
 * \brief Copy the feature names of a dataset into caller buffers.
 * \param handle Dataset to read
 * \param feature_names One writable buffer per name; each must be large
 *        enough for the name plus terminator, because std::strcpy is unbounded
 * \param num_feature_names Receives the number of names
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  const auto names = source->feature_names();
  *num_feature_names = static_cast<int>(names.size());
  int i = 0;
  while (i < *num_feature_names) {
    std::strcpy(feature_names[i], names[i].c_str());
    ++i;
  }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
513
/*!
 * \brief Destroy a dataset created through the C API.
 * \param handle Dataset to delete
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
519
/*!
 * \brief Write a dataset to a binary file.
 * \param handle Dataset to save
 * \param filename Forwarded to Dataset::SaveBinaryFile
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
527
/*!
 * \brief Set a named field of a dataset from a caller buffer.
 * \param handle Dataset to modify
 * \param field_name Name of the field to set
 * \param field_data Buffer holding num_element values of the given type
 * \param num_element Number of values in field_data
 * \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value is
 *        rejected
 * \return Status produced by the API_BEGIN/API_END wrapper; a
 *         std::runtime_error is thrown inside the wrapper when the type is
 *         unsupported or the dataset rejects the field
 */
DllExport int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  }
  // message previously read "erorr"
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
544
/*!
 * \brief Look up a named field of a dataset.
 *
 * The float lookup is tried first, then the int lookup.
 * \param handle Dataset to read
 * \param field_name Name of the field
 * \param out_len Receives the element count; forced to 0 when the returned
 *        pointer is nullptr
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
 * \return Status produced by the API_BEGIN/API_END wrapper; a
 *         std::runtime_error is thrown inside the wrapper when neither
 *         lookup succeeds
 */
DllExport int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a found-but-empty field reports zero length
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
564
/*!
 * \brief Report the value of Dataset::num_data().
 * \param handle Dataset to query
 * \param out Receives the count
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
572
/*!
 * \brief Report the value of Dataset::num_total_features().
 * \param handle Dataset to query
 * \param out Receives the count
 * \return Status produced by the API_BEGIN/API_END wrapper
 */
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
579
580
581

// ---- start of booster

Guolin Ke's avatar
typo  
Guolin Ke committed
582
// Constructs a Booster from a training dataset handle and a parameter string
// and hands the raw pointer back through *out.
DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  auto training_set = reinterpret_cast<const Dataset*>(train_data);
  // Held in a unique_ptr until the moment it is released to the caller.
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
592
// Constructs a Booster from a model file, reports the loaded model's current
// iteration count through *out_num_iterations, and hands the raw pointer
// back through *out.
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Held in a unique_ptr until the moment it is released to the caller.
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

// Deletes the Booster object behind `handle`.
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
609
610
611
612
613
614
615
616
617
618
// Calls Booster::MergeFrom on the booster behind `handle`, passing the
// booster behind `other_handle` as the source.
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

// Passes the dataset behind `valid_data` to Booster::AddValidData on the
// booster behind `handle`.
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

// Passes the dataset behind `train_data` to Booster::ResetTrainingData on the
// booster behind `handle`.
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

// Forwards the parameter string to Booster::ResetConfig on the booster behind
// `handle`.
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
643
// Stores NumberOfClasses() of the booster's underlying boosting object into
// *out_len.
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

650
// Runs Booster::TrainOneIter() once and writes 1 to *is_finished when it
// returns true, 0 otherwise.
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

// Runs Booster::TrainOneIter(grad, hess) once with caller-supplied gradient
// and hessian arrays and writes 1 to *is_finished when it returns true,
// 0 otherwise.
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
675
676
677
678
679
680
681
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
682
// Stores GetCurrentIteration() of the booster's underlying boosting object
// into *out_iteration.
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
688

Guolin Ke's avatar
Guolin Ke committed
689
// Stores the result of Booster::GetEvalCounts() into *out_len.
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
696
// Passes `out_strs` to Booster::GetEvalNames() and stores its return value
// into *out_len.
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

// Fetches the evaluation results for dataset index `data_idx` via
// GetEvalAt(), writes their count to *out_len and copies each value (as
// double) into out_results.
// NOTE(review): the caller presumably sizes out_results from
// LGBM_BoosterGetEvalCounts — no capacity check is done here.
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_values.size());
  size_t slot = 0;
  while (slot < eval_values.size()) {
    out_results[slot] = static_cast<double>(eval_values[slot]);
    ++slot;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
718
719
720
721
722
723
724
725
726
// Stores GetNumPredictAt(data_idx) of the booster's underlying boosting
// object into *out_len.
DllExport int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
727
// Forwards to Booster::GetPredictAt() for dataset index `data_idx`, passing
// the caller's result buffer and length pointer through unchanged.
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

737
738
// Builds a predictor for the requested iteration count and prediction type,
// then calls its Predict() with the input file name, the result file name
// and the header flag. Any positive `data_has_header` is treated as true.
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

751
752
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
753
754
755
756
757
758
759
760
761
762
763
764
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

// Stores num_row * GetNumPredOneRow(...) into *out_len, i.e. the total number
// of prediction values for `num_row` rows.
DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t values_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row * values_per_row);
  API_END();
}

775
// Predicts every row of a CSR matrix and writes the results into out_result.
// Row i's values are stored starting at offset i * GetNumPredOneRow(...);
// *out_len receives nrow * GetNumPredOneRow(...), where nrow = nindptr - 1.
// The unnamed int64_t parameter is not used. Rows are processed in an OpenMP
// parallel loop.
// NOTE(review): out_result is presumably sized by the caller via
// LGBM_BoosterCalcNumPredict — no capacity check is done here.
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int64_t values_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto sparse_row = row_at(i);
    auto scores = predictor.GetPredictFunction()(sparse_row);
    // 64-bit base offset of this row's slice in the output buffer.
    const int64_t base = i * values_per_row;
    const int num_scores = static_cast<int>(scores.size());
    for (int j = 0; j < num_scores; ++j) {
      out_result[base + j] = static_cast<double>(scores[j]);
    }
  }
  *out_len = nrow * values_per_row;
  API_END();
}
805

Guolin Ke's avatar
Guolin Ke committed
806
807
808
809
810
811
812
813
814
815
// Predicts every row of a CSC matrix and writes the results into out_result.
// Work is split into row ranges by Threading::For; each range builds its own
// set of per-column CSC_RowIterator objects, assembles each row from the
// column values whose magnitude exceeds kEpsilon, and stores that row's
// predictions starting at offset row * GetNumPredOneRow(...).
// *out_len receives num_row * GetNumPredOneRow(...).
// NOTE(review): Threading::For is instantiated with int64_t while the range
// callback takes data_size_t bounds — confirm the narrowing is intended.
DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t values_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  Threading::For<int64_t>(0, num_row,
    [&predictor, &out_result, values_per_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // One forward-only iterator per column, private to this row range.
    std::vector<CSC_RowIterator> column_iters;
    for (int c = 0; c < ncol; ++c) {
      column_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, c);
    }
    std::vector<std::pair<int, double>> sparse_row;
    for (int64_t r = start; r < end; ++r) {
      sparse_row.clear();
      for (int c = 0; c < ncol; ++c) {
        auto value = column_iters[c].Get(static_cast<int>(r));
        if (std::fabs(value) > kEpsilon) {
          sparse_row.emplace_back(c, value);
        }
      }
      auto scores = predictor.GetPredictFunction()(sparse_row);
      const int64_t base = r * values_per_row;
      const int num_scores = static_cast<int>(scores.size());
      for (int j = 0; j < num_scores; ++j) {
        out_result[base + j] = static_cast<double>(scores[j]);
      }
    }
  });
  *out_len = num_row * values_per_row;
  API_END();
}

851
852
// Predicts every row of a dense matrix and writes the results into
// out_result. Rows are converted to (column, value) pairs by
// RowPairFunctionFromDenseMatric and processed in an OpenMP parallel loop;
// row i's values are stored starting at offset i * GetNumPredOneRow(...).
// *out_len receives nrow * GetNumPredOneRow(...).
// NOTE(review): out_result is presumably sized by the caller via
// LGBM_BoosterCalcNumPredict — no capacity check is done here.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto row_at = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int64_t values_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto sparse_row = row_at(i);
    auto scores = predictor.GetPredictFunction()(sparse_row);
    // 64-bit base offset of this row's slice in the output buffer.
    const int64_t base = i * values_per_row;
    const int num_scores = static_cast<int>(scores.size());
    for (int j = 0; j < num_scores; ++j) {
      out_result[base + j] = static_cast<double>(scores[j]);
    }
  }
  *out_len = nrow * values_per_row;
  API_END();
}
877
878

// Forwards to Booster::SaveModelToFile() with the requested iteration count
// and target file name.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Dumps the model to a string via Booster::DumpModel().
// *out_len always receives the string length plus one (room for the
// terminating NUL). The string is copied into out_str only when that
// required size fits in buffer_len; otherwise out_str is left untouched so
// the caller can retry with a larger buffer.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  std::string dumped = booster->DumpModel(num_iteration);
  const int required_len = static_cast<int>(dumped.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
901

Guolin Ke's avatar
Guolin Ke committed
902
903
904
// Stores Booster::GetLeafValue(tree_idx, leaf_idx), converted to double,
// into *out_val.
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
922
// ---- start of helper functions
923
924
925

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
926
  if (data_type == C_API_DTYPE_FLOAT32) {
927
928
929
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
930
        std::vector<double> ret(num_col);
931
932
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
933
          ret[i] = static_cast<double>(*(tmp_ptr + i));
934
935
936
937
938
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
939
        std::vector<double> ret(num_col);
940
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
941
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
942
943
944
945
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
946
  } else if (data_type == C_API_DTYPE_FLOAT64) {
947
948
949
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
950
        std::vector<double> ret(num_col);
951
952
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
953
          ret[i] = static_cast<double>(*(tmp_ptr + i));
954
955
956
957
958
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
959
        std::vector<double> ret(num_col);
960
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
961
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
962
963
964
965
966
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
967
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
968
969
970
971
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
972
973
974
975
976
977
978
979
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
980
        }
Guolin Ke's avatar
Guolin Ke committed
981
982
983
      }
      return ret;
    };
984
  }
Guolin Ke's avatar
Guolin Ke committed
985
  return nullptr;
986
987
988
989
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
990
  if (data_type == C_API_DTYPE_FLOAT32) {
991
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
992
    if (indptr_type == C_API_DTYPE_INT32) {
993
994
995
996
997
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
998
        for (int64_t i = start; i < end; ++i) {
999
1000
1001
1002
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1003
    } else if (indptr_type == C_API_DTYPE_INT64) {
1004
1005
1006
1007
1008
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1009
        for (int64_t i = start; i < end; ++i) {
1010
1011
1012
1013
1014
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1015
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1016
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1017
    if (indptr_type == C_API_DTYPE_INT32) {
1018
1019
1020
1021
1022
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1023
        for (int64_t i = start; i < end; ++i) {
1024
1025
1026
1027
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1028
    } else if (indptr_type == C_API_DTYPE_INT64) {
1029
1030
1031
1032
1033
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1034
        for (int64_t i = start; i < end; ++i) {
1035
1036
1037
1038
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1039
1040
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1041
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1042
1043
}

Guolin Ke's avatar
Guolin Ke committed
1044
1045
1046
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1047
  if (data_type == C_API_DTYPE_FLOAT32) {
1048
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1049
    if (col_ptr_type == C_API_DTYPE_INT32) {
1050
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1051
1052
1053
1054
1055
1056
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1057
        }
Guolin Ke's avatar
Guolin Ke committed
1058
1059
1060
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1061
      };
Guolin Ke's avatar
Guolin Ke committed
1062
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1063
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1064
1065
1066
1067
1068
1069
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1070
        }
Guolin Ke's avatar
Guolin Ke committed
1071
1072
1073
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1074
      };
Guolin Ke's avatar
Guolin Ke committed
1075
    }
Guolin Ke's avatar
Guolin Ke committed
1076
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1077
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1078
    if (col_ptr_type == C_API_DTYPE_INT32) {
1079
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1080
1081
1082
1083
1084
1085
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1086
        }
Guolin Ke's avatar
Guolin Ke committed
1087
1088
1089
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1090
      };
Guolin Ke's avatar
Guolin Ke committed
1091
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1092
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1093
1094
1095
1096
1097
1098
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1099
        }
Guolin Ke's avatar
Guolin Ke committed
1100
1101
1102
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1103
      };
Guolin Ke's avatar
Guolin Ke committed
1104
1105
1106
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1107
1108
}

Guolin Ke's avatar
Guolin Ke committed
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
// Binds this iterator to column `col_idx` of a CSC matrix by storing the
// callable produced by IterateFunctionFromCSC in iter_fun_.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when no element is stored
// for that row. Stored elements are consumed in order until the cached row
// index reaches or passes `idx`, or the column is exhausted; the last
// element read stays cached in cur_idx_/cur_val_.
// NOTE(review): the cursor only moves forward, so callers presumably query
// rows in non-decreasing order — confirm against the call sites.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Negative row index is the end-of-column sentinel.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0f;
  }
  return cur_val_;
}

// Returns the next stored (row index, value) pair of the column and advances
// the cursor. Once the underlying iterator function yields a negative row
// index the iterator is marked finished; that sentinel pair is returned, and
// every later call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  is_end_ = (entry.first < 0);
  return entry;
}