c_api.cpp 32.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
wxchan's avatar
wxchan committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
32
33
34
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
35
36
37
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
38
39
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
40
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
41
42
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
43

wxchan's avatar
wxchan committed
44
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
45

Guolin Ke's avatar
Guolin Ke committed
46
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
47
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
48
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
49
50

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
51
52
53
54
55
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
56
57
58
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
59

Guolin Ke's avatar
Guolin Ke committed
60
  }
61

wxchan's avatar
wxchan committed
62
63
64
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
wxchan's avatar
wxchan committed
87
88
89
90
91
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
92
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
93
94
95
96
97
98
99
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
100
101
102
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
103
104

    config_.Set(param);
105
106
107
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
114
115
116
117
118
119

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    
wxchan's avatar
wxchan committed
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
140

141
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
142
    std::lock_guard<std::mutex> lock(mutex_);
143
144
145
146
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
147
    std::lock_guard<std::mutex> lock(mutex_);
148
149
150
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
151
152
153
154
155
156
157
158
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
159
160
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
161
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
162
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
163
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
164
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
165
166
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
167
    }
Guolin Ke's avatar
Guolin Ke committed
168
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
169
170
  }

wxchan's avatar
wxchan committed
171
172
173
174
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
175
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
176
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
177
    return predictor_->GetPredictFunction()(features);
178
179
  }

180
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
Guolin Ke's avatar
Guolin Ke committed
181
    std::lock_guard<std::mutex> lock(mutex_);
182
183
184
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
185
186
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
187
  }
188

wxchan's avatar
wxchan committed
189
190
191
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
192

wxchan's avatar
wxchan committed
193
194
195
196
197
198
199
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
200

wxchan's avatar
wxchan committed
201
202
203
204
205
206
207
208
209
210
211
212
213
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
214
private:
215

wxchan's avatar
wxchan committed
216
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
217
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
218
219
220
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
221
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
222
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
223
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
224
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
225
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
226
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
227
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
228
229
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
230
231
232
};

}
Guolin Ke's avatar
Guolin Ke committed
233
234
235

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
236
/*!
* \brief Expose the message returned by LastErrorMsg() to C callers
* \return Pointer obtained from LastErrorMsg()
*/
DllExport const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

wxchan's avatar
wxchan committed
240
/*!
* \brief Load a dataset from a file
* \param filename File to load
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle; when given,
*        the new dataset is loaded aligned with that dataset
* \param out Receives the handle of the loaded dataset
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle* reference,
  DatasetHandle* out) {
  API_BEGIN();
  IOConfig io_config;
  auto param = ConfigBase::Str2Map(parameters);
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr, filename);
  if (reference != nullptr) {
    const Dataset* align_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, align_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
258
/*!
* \brief Build a dataset from a dense matrix
* \param data Pointer to the matrix values
* \param data_type Element type code forwarded to RowFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Layout flag forwarded to RowFunctionFromDenseMatric
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle whose
*        feature mappers are copied; when nullptr, bins are constructed
*        from sampled rows
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle* reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    ret.reset(new Dataset(nrow, io_config.num_class));
    const Dataset* mapper_source = reinterpret_cast<const Dataset*>(*reference);
    ret->CopyFeatureMapperFrom(mapper_source, io_config.is_enable_sparse);
  } else {
    // sample rows first, keeping only values whose magnitude exceeds 1e-15
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (size_t pos = 0; pos < sample_indices.size(); ++pos) {
      auto sampled_row = get_row_fun(static_cast<int>(sample_indices[pos]));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        if (std::fabs(sampled_row[col]) <= 1e-15) { continue; }
        sample_values[col].push_back(sampled_row[col]);
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row, tagged with the id of the thread that handles it
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
307
/*!
* \brief Build a dataset from a CSR matrix
* \param indptr Row pointer array
* \param indptr_type Element type code of indptr
* \param indices Column indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param nindptr Length of indptr; the row count is nindptr - 1
* \param nelem Number of stored elements
* \param num_col Column count; checked to be at least the number of columns
*        seen while sampling
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle whose
*        feature mappers are copied; when nullptr, bins are constructed
*        from sampled rows
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle* reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    ret.reset(new Dataset(nrow, io_config.num_class));
    const Dataset* mapper_source = reinterpret_cast<const Dataset*>(*reference);
    ret->CopyFeatureMapperFrom(mapper_source, io_config.is_enable_sparse);
  } else {
    // sample rows first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t pos = 0; pos < sample_indices.size(); ++pos) {
      auto sampled_row = get_row_fun(static_cast<int>(sample_indices[pos]));
      for (std::pair<int, double>& entry : sampled_row) {
        const size_t col = static_cast<size_t>(entry.first);
        if (col >= sample_values.size()) {
          // grow the per-column buckets up to and including this column
          sample_values.resize(col + 1);
        }
        if (std::fabs(entry.second) > 1e-15) {
          // record the sampled value of this feature
          sample_values[entry.first].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row, tagged with the id of the thread that handles it
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
369
/*!
* \brief Build a dataset from a CSC matrix
* \param col_ptr Column pointer array
* \param col_ptr_type Element type code of col_ptr
* \param indices Row indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param ncol_ptr Length of col_ptr; the column count is ncol_ptr - 1
* \param nelem Number of stored elements
* \param num_row Number of rows
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle whose
*        feature mappers are copied; when nullptr, bins are constructed
*        from sampled rows
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle* reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    ret.reset(new Dataset(nrow, io_config.num_class));
    const Dataset* mapper_source = reinterpret_cast<const Dataset*>(*reference);
    ret->CopyFeatureMapperFrom(mapper_source, io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first, one column at a time
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto cur_col = get_col_fun(col);
      sample_values[col] = SampleFromOneColumn(cur_col, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every column, tagged with the id of the thread that handles it
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(i);
    ret->PushOneColumn(tid, i, one_col);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
419
/*!
* \brief Create a new dataset from selected rows of an existing one
* \param handle Pointer to the handle of the full dataset
* \param used_row_indices Row indices to keep
* \param num_used_row_indices Number of entries in used_row_indices
* \param parameters Parameter string parsed into an IOConfig
* \param out Receives the handle of the subset
*/
DllExport int LGBM_DatasetGetSubset(
  const DatasetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
439
/*!
* \brief Set the feature names of a dataset
* \param handle Dataset handle
* \param feature_names Array of C strings, one per feature
* \param num_feature_names Number of entries read from feature_names
*/
DllExport int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  int64_t next = 0;
  while (next < num_feature_names) {
    names.emplace_back(feature_names[next]);
    ++next;
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
453
/*!
* \brief Delete the dataset behind a handle
* \param handle Dataset handle; the pointed-to Dataset is deleted
*/
DllExport int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
459
/*!
* \brief Write a dataset through Dataset::SaveBinaryFile
* \param handle Dataset handle
* \param filename Target file name
*/
DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
467
/*!
* \brief Set a named field of a dataset
* \param handle Dataset handle
* \param field_name Name of the field
* \param field_data Pointer to the values, interpreted according to type
* \param num_element Number of values; narrowed to int32_t before use
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value
*        is reported as an error
*/
DllExport int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  }
  // reached for an unsupported type as well as for a rejected field
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
484
/*!
* \brief Look up a named field of a dataset, trying float fields first and
*        int fields second
* \param handle Dataset handle
* \param field_name Name of the field
* \param out_len Receives the number of elements; forced to 0 when the
*        returned pointer is nullptr
* \param out_ptr Receives the pointer to the field data
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
*/
DllExport int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
504
/*!
* \brief Report Dataset::num_data()
* \param handle Dataset handle
* \param out Receives the value
*/
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
512
/*!
* \brief Report Dataset::num_total_features()
* \param handle Dataset handle
* \param out Receives the value
*/
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
519
520
521
522


// ---- start of booster

Guolin Ke's avatar
typo  
Guolin Ke committed
523
/*!
* \brief Create a booster for training
* \param train_data Handle of the training dataset
* \param parameters Parameter string forwarded to the Booster constructor
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
533
/*!
* \brief Create a booster from a model file
* \param filename Model file name
* \param out_num_iterations Receives NumberOfTotalModel() divided by
*        NumberOfClasses() of the loaded model
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(
    model->NumberOfTotalModel() / model->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*!
* \brief Delete the booster behind a handle
* \param handle Booster handle; the pointed-to Booster is deleted
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

wxchan's avatar
wxchan committed
551
552
553
554
555
556
557
558
559
560
/*!
* \brief Merge another booster into this one via Booster::MergeFrom
* \param handle Handle of the booster that receives the merge
* \param other_handle Handle of the booster that is merged in
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

/*!
* \brief Register a validation dataset with a booster
* \param handle Booster handle
* \param valid_data Handle of the validation dataset
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  const Dataset* validation_set = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(validation_set);
  API_END();
}

/*!
* \brief Replace the training dataset of a booster
* \param handle Booster handle
* \param train_data Handle of the new training dataset
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  const Dataset* new_train_set = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(new_train_set);
  API_END();
}

/*!
* \brief Apply a new parameter string to a booster via Booster::ResetConfig
* \param handle Booster handle
* \param parameters Parameter string
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

/*!
* \brief Report Boosting::NumberOfClasses() of a booster
* \param handle Booster handle
* \param out_len Receives the number of classes
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

592
/*!
* \brief Run one training iteration
* \param handle Booster handle
* \param is_finished Receives 1 when Booster::TrainOneIter returns true,
*        otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller supplied gradients
* \param handle Booster handle
* \param grad Gradients forwarded to Booster::TrainOneIter
* \param hess Hessians forwarded to Booster::TrainOneIter
* \param is_finished Receives 1 when Booster::TrainOneIter returns true,
*        otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
/*!
* \brief Forward to Booster::RollbackOneIter
* \param handle Booster handle
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report Boosting::GetCurrentIteration() of a booster
* \param handle Booster handle
* \param out_iteration Receives the value
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get the number of evaluation results, i.e. the total number of
*        names reported by the booster's training metrics
* \param handle Booster handle
* \param out_len Receives the count
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

/*!
* \brief Get the names of the evaluation results
* \param handle Booster handle
* \param out_len Receives the number of names written
* \param out_strs Caller allocated string buffers, one per name; they are
*        filled by Booster::GetEvalNames
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results of one dataset into a float buffer
* \param handle Booster handle
* \param data_idx Dataset index forwarded to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller allocated buffer that receives the results,
*        each converted to float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto scores = booster->GetBoosting()->GetEvalAt(data_idx);
  const size_t count = scores.size();
  *out_len = static_cast<int64_t>(count);
  float* cursor = out_results;
  for (size_t k = 0; k < count; ++k, ++cursor) {
    *cursor = static_cast<float>(scores[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
668
/*!
* \brief Fetch the predictions the booster holds for one dataset
* \param handle Booster handle
* \param data_idx Dataset index forwarded to Booster::GetPredictAt
* \param out_len Receives the length reported by Booster::GetPredictAt
* \param out_result Caller allocated buffer filled by Booster::GetPredictAt
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  int written = 0;
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

680
681
/*!
* \brief Predict for every row of a data file and write the results to a file
* \param handle Booster handle
* \param data_filename Input data file
* \param data_has_header Treated as true when greater than 0
* \param predict_type Prediction type forwarded to PrepareForPrediction
* \param num_iteration Iteration count, narrowed to int, forwarded to
*        PrepareForPrediction
* \param result_filename Output file
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

694
/*!
* \brief Predict for every row of a CSR matrix
* \param handle Booster handle
* \param indptr Row pointer array
* \param indptr_type Element type code of indptr
* \param indices Column indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param nindptr Length of indptr; the row count is nindptr - 1
* \param nelem Number of stored elements
* \param (unnamed) Unused int64_t argument kept for the signature
* \param predict_type Prediction type forwarded to PrepareForPrediction
* \param num_iteration Iteration count, narrowed to int, forwarded to
*        PrepareForPrediction
* \param out_len Receives rows times values-per-row
* \param out_result Caller allocated buffer; row i starts at offset
*        i times values-per-row
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // values per row: one per class, multiplied by the iteration count for
  // leaf index prediction
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
  // 64-bit stride so row offsets and the total length cannot overflow int
  const int64_t row_stride = num_preb_in_one_row;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * row_stride;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * row_stride;
  API_END();
}
732
733
734

// C API entry point: predicts every row of a dense matrix and writes the
// results as floats into `out_result`, row after row.
//   handle        - opaque handle, reinterpreted as Booster*
//   data/data_type, nrow, ncol, is_row_major
//                 - dense matrix description, forwarded to
//                   RowPairFunctionFromDenseMatric
//   predict_type  - forwarded to PrepareForPrediction; C_API_PREDICT_LEAF_INDEX
//                   changes the number of values written per row (see below)
//   num_iteration - narrowed to int and forwarded to PrepareForPrediction
//   out_len       - receives nrow * (values per row)
//   out_result    - caller-provided buffer; must be able to hold *out_len floats
// NOTE(review): API_BEGIN/API_END are macros defined elsewhere in this file;
// presumably they produce the int status return - confirm.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  // Number of output slots reserved per input row.
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }

  // Offsets into out_result are computed in 64 bits: nrow * values-per-row can
  // exceed the range of a 32-bit int even when each factor fits.
  const int64_t row_stride = num_preb_in_one_row;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    float* out_row = out_result + row_stride * i;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_row[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = row_stride * nrow;
  API_END();
}
767
768

// C API entry point: forwards to Booster::SaveModelToFile on the booster
// behind `handle`.
//   handle        - opaque handle, reinterpreted as Booster*
//   num_iteration - passed through unchanged to SaveModelToFile
//   filename      - passed through unchanged to SaveModelToFile
// NOTE(review): API_BEGIN/API_END are macros defined elsewhere in this file;
// presumably they produce the int status return - confirm.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

// C API entry point: dumps the model held by the booster behind `handle` as a
// string.
//   handle     - opaque handle, reinterpreted as Booster*
//   buffer_len - capacity, in bytes, of the buffer that *out_str points to
//   out_len    - always receives the dump length plus one (for the terminating
//                NUL), so a caller can detect a too-small buffer and retry
//   out_str    - *out_str is written only when the whole dump, including the
//                terminator, fits into buffer_len bytes
// NOTE(review): API_BEGIN/API_END are macros defined elsewhere in this file;
// presumably they produce the int status return - confirm.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel();
  const int64_t required_len = static_cast<int64_t>(dumped.size()) + 1;
  *out_len = required_len;
  const bool fits_in_buffer = required_len <= buffer_len;
  if (fits_in_buffer) {
    std::strcpy(*out_str, dumped.c_str());
  }
  API_END();
}
790

Guolin Ke's avatar
Guolin Ke committed
791
// ---- start of some helper functions
792
793
794

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
795
  if (data_type == C_API_DTYPE_FLOAT32) {
796
797
798
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
799
        std::vector<double> ret(num_col);
800
801
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
802
          ret[i] = static_cast<double>(*(tmp_ptr + i));
803
804
805
806
807
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
808
        std::vector<double> ret(num_col);
809
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
810
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
811
812
813
814
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
815
  } else if (data_type == C_API_DTYPE_FLOAT64) {
816
817
818
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
819
        std::vector<double> ret(num_col);
820
821
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
822
          ret[i] = static_cast<double>(*(tmp_ptr + i));
823
824
825
826
827
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
828
        std::vector<double> ret(num_col);
829
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
830
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
831
832
833
834
835
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
836
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
837
838
839
840
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
841
842
843
844
845
846
847
848
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
849
        }
Guolin Ke's avatar
Guolin Ke committed
850
851
852
      }
      return ret;
    };
853
  }
Guolin Ke's avatar
Guolin Ke committed
854
  return nullptr;
855
856
857
858
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
859
  if (data_type == C_API_DTYPE_FLOAT32) {
860
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
861
    if (indptr_type == C_API_DTYPE_INT32) {
862
863
864
865
866
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
867
        for (int64_t i = start; i < end; ++i) {
868
869
870
871
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
872
    } else if (indptr_type == C_API_DTYPE_INT64) {
873
874
875
876
877
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
878
        for (int64_t i = start; i < end; ++i) {
879
880
881
882
883
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
884
  } else if (data_type == C_API_DTYPE_FLOAT64) {
885
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
886
    if (indptr_type == C_API_DTYPE_INT32) {
887
888
889
890
891
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
892
        for (int64_t i = start; i < end; ++i) {
893
894
895
896
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
897
    } else if (indptr_type == C_API_DTYPE_INT64) {
898
899
900
901
902
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
903
        for (int64_t i = start; i < end; ++i) {
904
905
906
907
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
908
909
910
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
911
912
913
914
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
915
  if (data_type == C_API_DTYPE_FLOAT32) {
916
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
917
    if (col_ptr_type == C_API_DTYPE_INT32) {
918
919
920
921
922
923
924
925
926
927
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
928
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
929
930
931
932
933
934
935
936
937
938
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
939
    } 
Guolin Ke's avatar
Guolin Ke committed
940
  } else if (data_type == C_API_DTYPE_FLOAT64) {
941
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
942
    if (col_ptr_type == C_API_DTYPE_INT32) {
943
944
945
946
947
948
949
950
951
952
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
953
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
954
955
956
957
958
959
960
961
962
963
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
964
965
966
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
967
968
}

969
// Picks, from one sparse column, the values stored at the requested rows.
//   data    - (row index, value) pairs of the column
//   indices - row indices to sample
// Both inputs are walked with a single forward-only cursor, so the result is
// only meaningful when `data` is ordered by row index and `indices` is
// non-decreasing. Requested rows that have no stored entry contribute nothing
// to the result (they are skipped, not emitted as zero).
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto data_end = data.end();
  for (const int wanted_row : indices) {
    // skip stored entries that lie before the requested row
    for (; cursor != data_end && cursor->first < wanted_row; ++cursor) {
    }
    if (cursor == data_end) {
      // column exhausted: no later request can match either
      break;
    }
    if (cursor->first == wanted_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}