c_api.cpp 26.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19

Guolin Ke's avatar
Guolin Ke committed
20
21
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
26
27
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
32
33
  }

  Booster(const Dataset* train_data, 
    std::vector<const Dataset*> valid_data, 
    std::vector<std::string> valid_names,
    const char* parameters)
Guolin Ke's avatar
Guolin Ke committed
34
    :train_data_(train_data), valid_datas_(valid_data) {
Guolin Ke's avatar
Guolin Ke committed
35
36
37
    config_.LoadFromString(parameters);
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
38
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
39
40
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
41
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
Guolin Ke's avatar
Guolin Ke committed
42
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
43
44
45
46
47
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
Guolin Ke's avatar
Guolin Ke committed
48
    // create training metric
Guolin Ke's avatar
Guolin Ke committed
49
    for (auto metric_type : config_.metric_types) {
Guolin Ke's avatar
Guolin Ke committed
50
51
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
Guolin Ke's avatar
Guolin Ke committed
52
53
54
      if (metric == nullptr) { continue; }
      metric->Init("training", train_data_->metadata(),
        train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
55
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
56
    }
Guolin Ke's avatar
Guolin Ke committed
57
    train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
    // add metric for validation data
    for (size_t i = 0; i < valid_datas_.size(); ++i) {
      valid_metrics_.emplace_back();
      for (auto metric_type : config_.metric_types) {
Guolin Ke's avatar
Guolin Ke committed
62
        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
        if (metric == nullptr) { continue; }
        metric->Init(valid_names[i].c_str(),
          valid_datas_[i]->metadata(),
          valid_datas_[i]->num_data());
Guolin Ke's avatar
Guolin Ke committed
67
        valid_metrics_.back().push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
68
      }
Guolin Ke's avatar
Guolin Ke committed
69
      valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
70
    }
Guolin Ke's avatar
Guolin Ke committed
71
    valid_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
72
    // initialize the objective function
Guolin Ke's avatar
Guolin Ke committed
73
74
75
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }
Guolin Ke's avatar
Guolin Ke committed
76
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
77
78
    boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    // add validation data into boosting
    for (size_t i = 0; i < valid_datas_.size(); ++i) {
      boosting_->AddDataset(valid_datas_[i],
Guolin Ke's avatar
Guolin Ke committed
82
        Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
Guolin Ke's avatar
Guolin Ke committed
83
84
85
86
    }
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
87

Guolin Ke's avatar
Guolin Ke committed
88
  }
89
90
91
92
93
94
95
96
97
98
99

  bool TrainOneIter() {
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

  void PrepareForPrediction(int num_used_model, int predict_type) {
    boosting_->SetNumUsedModel(num_used_model);
Guolin Ke's avatar
Guolin Ke committed
100
101
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
102
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
103
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
104
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
105
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
106
107
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
108
    }
Guolin Ke's avatar
Guolin Ke committed
109
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
  }

  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
114
115
  }

116
117
118
119
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
120
121
122
  void SaveModelToFile(int num_used_model, const char* filename) {
    boosting_->SaveModelToFile(num_used_model, true, filename);
  }
123
  
Guolin Ke's avatar
Guolin Ke committed
124
  const Boosting* GetBoosting() const { return boosting_.get(); }
125

126
127
  const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); }

Guolin Ke's avatar
Guolin Ke committed
128
  const inline int NumberOfClasses() const { return boosting_->NumberOfClasses(); }
129

Guolin Ke's avatar
Guolin Ke committed
130
private:
131

Guolin Ke's avatar
Guolin Ke committed
132
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
137
138
139
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Training data */
  const Dataset* train_data_;
  /*! \brief Validation data */
  std::vector<const Dataset*> valid_datas_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
140
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
141
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
142
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
143
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
144
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
145
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
146
  std::unique_ptr<Predictor> predictor_;
147

Guolin Ke's avatar
Guolin Ke committed
148
149
150
};

}
Guolin Ke's avatar
Guolin Ke committed
151
152
153

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
154
DllExport const char* LGBM_GetLastError() {
  // Expose the text stored by LastErrorMsg() as a C string.
  const auto& last_message = LastErrorMsg();
  return last_message.c_str();
}

DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    // Load aligned with the dataset wrapped by *reference.
    const auto* ref_handle = reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference);
    *out = new std::shared_ptr<Dataset>(
      loader.LoadFromFileAlignWithOtherDataset(filename, ref_handle->get()));
  } else {
    *out = new std::shared_ptr<Dataset>(loader.LoadFromFile(filename));
  }
  API_END();
}

DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  // No parameter string is parsed here: the loader runs on a default-constructed config.
  OverallConfig default_config;
  DatasetLoader bin_loader(default_config.io_config, nullptr);
  // NOTE(review): the (0, 1) arguments presumably mean rank 0 of 1 machine — confirm.
  *out = new std::shared_ptr<Dataset>(bin_loader.LoadFromBinFile(filename, 0, 1));
  API_END();
}

DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  // get_row_fun(r) returns row r of the dense matrix as doubles.
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // Reuse the feature mappers of the reference dataset.
    const Dataset* ref_dataset = reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get();
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(ref_dataset, config.io_config.is_enable_sparse);
  } else {
    // Sample at most bin_construct_sample_cnt rows to build the bin mappers.
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto row_idx : sample_indices) {
      const auto sampled_row = get_row_fun(static_cast<int>(row_idx));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        const double value = sampled_row[col];
        // only non-zero values are kept per column
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // Push every row; each thread passes its own id to PushOneRow.
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(row);
    ret->PushOneRow(tid, row, one_row);
  }
  ret->FinishLoad();
  *out = new std::shared_ptr<Dataset>(ret.release());
  API_END();
}

235
236
/*!
* \brief Build a dataset from a CSR matrix.
*        Without a reference, bin mappers are built from a row sample; with a
*        reference, its feature mappers are copied.
* \param num_col Declared number of columns; the sampled feature table is sized to it
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Pre-size to the declared column count. Otherwise trailing columns that are
    // all-zero in the sample would be dropped, and the dataset would end up with
    // fewer features than num_col.
    std::vector<std::vector<double>> sample_values(num_col > 0 ? static_cast<size_t>(num_col) : 0);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          const size_t feature_idx = static_cast<size_t>(inner_data.first);
          if (feature_idx >= sample_values.size()) {
            // Column index beyond num_col: grow so the CHECK below reports it.
            sample_values.resize(feature_idx + 1);
          }
          sample_values[feature_idx].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(),
      config.io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = new std::shared_ptr<Dataset>(ret.release());
  API_END();
}

296
297
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  // get_col_fun(c) returns the stored entries of column c.
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  std::unique_ptr<Dataset> ret;
  const int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // Reuse the feature mappers of the reference dataset.
    const Dataset* ref_dataset = reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get();
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(ref_dataset, config.io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // Sample row indices, then gather the sampled values column by column.
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto column = get_col_fun(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  const int64_t num_cols = ncol_ptr - 1;
#pragma omp parallel for schedule(guided)
  for (int col = 0; col < num_cols; ++col) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(col);
    ret->PushOneColumn(tid, col, one_col);
  }
  ret->FinishLoad();
  *out = new std::shared_ptr<Dataset>(ret.release());
  API_END();
}

345
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  // Handles are heap-allocated shared_ptr<Dataset> objects; delete the wrapper.
  auto* dataset_ref = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
  delete dataset_ref;
  API_END();
}

DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle)->get();
  dataset->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named metadata field of a dataset.
* \param num_element Number of elements; must be non-negative and fit in int32_t
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* Throws std::runtime_error (reported through API_END) if the count is out of
* range, the type is unsupported, or the field is not found.
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
  // The setters below take an int32_t count; reject values that would be
  // silently truncated by the narrowing cast.
  const int32_t num_element32 = static_cast<int32_t>(num_element);
  if (num_element < 0 || static_cast<int64_t>(num_element32) != num_element) {
    throw std::runtime_error("num_element is out of range");
  }
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->get()->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), num_element32);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->get()->SetIntField(field_name, reinterpret_cast<const int*>(field_data), num_element32);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle)->get();
  // Try the float-typed fields first, then the int-typed ones.
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  API_END();
}

DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle)->get();
  *out = dataset->num_data();
  API_END();
}

DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle)->get();
  *out = dataset->num_total_features();
  API_END();
}
410
411
412
413
414
415
416
417
418
419


// ---- start of booster

DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const DatesetHandle valid_datas[],
  const char* valid_names[],
  int n_valid_datas,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // Dataset handles wrap a shared_ptr<Dataset>; the booster receives raw pointers.
  auto unwrap_dataset = [](DatesetHandle h) -> const Dataset* {
    return reinterpret_cast<const std::shared_ptr<Dataset>*>(h)->get();
  };
  const Dataset* p_train_data = unwrap_dataset(train_data);
  std::vector<const Dataset*> p_valid_datas;
  std::vector<std::string> p_valid_names;
  for (int k = 0; k < n_valid_datas; ++k) {
    p_valid_datas.push_back(unwrap_dataset(valid_datas[k]));
    p_valid_names.push_back(std::string(valid_names[k]));
  }
  *out = new std::shared_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
  API_END();
}

DllExport int LGBM_BoosterLoadFromModelfile(
  const char* filename,
  BoosterHandle* out) {
  API_BEGIN();
  // The returned handle is a heap-allocated shared_ptr owning the new Booster.
  std::shared_ptr<Booster>* booster_handle = new std::shared_ptr<Booster>(new Booster(filename));
  *out = booster_handle;
  API_END();
}

DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  // Handles are heap-allocated shared_ptr<Booster> objects; delete the wrapper.
  auto* booster_ref = reinterpret_cast<std::shared_ptr<Booster>*>(handle);
  delete booster_ref;
  API_END();
}

DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  // Translate the bool result into the C API's 0/1 flag.
  const bool finished = ref_booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  // Train with caller-supplied gradients/hessians and report completion as 0/1.
  const bool finished = ref_booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

DllExport int LGBM_BoosterEval(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  const auto eval_values = ref_booster->GetBoosting()->GetEvalAt(data);
  *out_len = static_cast<int64_t>(eval_values.size());
  // The caller's buffer must hold at least *out_len floats.
  float* dst = out_results;
  for (const auto& value : eval_values) {
    *dst++ = static_cast<float>(value);
  }
  API_END();
}

DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
  int64_t* out_len,
  const float** out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  int score_len = 0;
  const float* scores = ref_booster->GetTrainingScore(&score_len);
  *out_result = scores;
  *out_len = static_cast<int64_t>(score_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
497
498
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get()->GetBoosting();
  int predict_len = 0;
  boosting->GetPredictAt(data, out_result, &predict_len);
  *out_len = static_cast<int64_t>(predict_len);
  API_END();
}

510
511
512
513
514
515
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  int predict_type,
  int64_t n_used_trees,
  int data_has_header,
  const char* data_filename,
  const char* result_filename) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
  // Only strictly positive values mean "has header".
  const bool has_header = (data_has_header > 0);
  ref_booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

524
/*!
* \brief Predict every row of a CSR matrix.
*        Row i's results are written to out_result[i * num_class .. i * num_class + num_class - 1].
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t n_used_trees,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int num_class = ref_booster->NumberOfClasses();
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // Compute the offset in size_t: i * num_class in int can overflow (UB)
    // for large nrow * num_class.
    double* row_out = out_result + static_cast<size_t>(i) * static_cast<size_t>(num_class);
    // NOTE(review): only num_class values are copied per row; for leaf-index
    // prediction the result likely has one entry per tree — confirm the contract.
    for (int j = 0; j < num_class; ++j) {
      row_out[j] = prediction_result[j];
    }
  }
  API_END();
}
553
554
555

/*!
* \brief Predict every row of a dense matrix.
*        Row i's results are written to out_result[i * num_class .. i * num_class + num_class - 1].
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t n_used_trees,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int num_class = ref_booster->NumberOfClasses();
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // Compute the offset in size_t: i * num_class in int can overflow (UB)
    // for large nrow * num_class.
    double* row_out = out_result + static_cast<size_t>(i) * static_cast<size_t>(num_class);
    // NOTE(review): only num_class values are copied per row; for leaf-index
    // prediction the result likely has one entry per tree — confirm the contract.
    for (int j = 0; j < num_class; ++j) {
      row_out[j] = prediction_result[j];
    }
  }
  API_END();
}
579
580
581

DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_used_model,
  const char* filename) {
  API_BEGIN();
  auto* booster_ref = reinterpret_cast<std::shared_ptr<Booster>*>(handle);
  (*booster_ref)->SaveModelToFile(num_used_model, filename);
  API_END();
}
588

Guolin Ke's avatar
Guolin Ke committed
589
// ---- start of helper functions
590
591
592

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
593
  if (data_type == C_API_DTYPE_FLOAT32) {
594
595
596
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
597
        std::vector<double> ret(num_col);
598
599
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
600
          ret[i] = static_cast<double>(*(tmp_ptr + i));
601
602
603
604
605
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
606
        std::vector<double> ret(num_col);
607
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
608
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
609
610
611
612
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
613
  } else if (data_type == C_API_DTYPE_FLOAT64) {
614
615
616
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
617
        std::vector<double> ret(num_col);
618
619
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
620
          ret[i] = static_cast<double>(*(tmp_ptr + i));
621
622
623
624
625
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
626
        std::vector<double> ret(num_col);
627
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
628
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
629
630
631
632
633
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
634
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
635
636
637
638
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
639
640
641
642
643
644
645
646
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
647
        }
Guolin Ke's avatar
Guolin Ke committed
648
649
650
      }
      return ret;
    };
651
  }
Guolin Ke's avatar
Guolin Ke committed
652
  return nullptr;
653
654
655
656
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
657
  if (data_type == C_API_DTYPE_FLOAT32) {
658
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
659
    if (indptr_type == C_API_DTYPE_INT32) {
660
661
662
663
664
665
666
667
668
669
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
670
    } else if (indptr_type == C_API_DTYPE_INT64) {
671
672
673
674
675
676
677
678
679
680
681
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
682
  } else if (data_type == C_API_DTYPE_FLOAT64) {
683
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
684
    if (indptr_type == C_API_DTYPE_INT32) {
685
686
687
688
689
690
691
692
693
694
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
695
    } else if (indptr_type == C_API_DTYPE_INT64) {
696
697
698
699
700
701
702
703
704
705
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
706
707
708
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
709
710
711
712
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
713
  if (data_type == C_API_DTYPE_FLOAT32) {
714
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
715
    if (col_ptr_type == C_API_DTYPE_INT32) {
716
717
718
719
720
721
722
723
724
725
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
726
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
727
728
729
730
731
732
733
734
735
736
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
737
    } 
Guolin Ke's avatar
Guolin Ke committed
738
  } else if (data_type == C_API_DTYPE_FLOAT64) {
739
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
740
    if (col_ptr_type == C_API_DTYPE_INT32) {
741
742
743
744
745
746
747
748
749
750
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
751
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
752
753
754
755
756
757
758
759
760
761
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
762
763
764
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
}

// Picks the stored values of one sparse column at the requested row indices.
// `data` holds (row index, value) pairs; rows absent from `data` are skipped
// (nothing is emitted for them). A single forward cursor walks `data`, so both
// `data` and `indices` are expected to be sorted ascending by row; otherwise
// matches may be missed. Repeated indices emit the matching value repeatedly.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto stop = data.end();
  for (const size_t wanted : indices) {
    const int target_row = static_cast<int>(wanted);
    while (cursor != stop && cursor->first < target_row) {
      ++cursor;
    }
    if (cursor != stop && cursor->first == target_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}