"python-package/vscode:/vscode.git/clone" did not exist on "9a4e70687d5c0732ca895959f418c3f923f2e85a"
c_api.cpp 32.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <limits>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>

#include "./application/predictor.hpp"
#include "./boosting/gbdt.h"

Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
28
29
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
30
31
32
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
33
34
35
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
36
37
38
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
39
40
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
41
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
42
43
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
44

wxchan's avatar
wxchan committed
45
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
46

Guolin Ke's avatar
Guolin Ke committed
47
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
48
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
49
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
50
51

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
52
53
54
55
56
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
57
58
59
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
60

Guolin Ke's avatar
Guolin Ke committed
61
  }
62

wxchan's avatar
wxchan committed
63
64
65
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
wxchan's avatar
wxchan committed
88
89
90
91
92
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
94
95
96
97
98
99
100
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
101
102
103
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
104
105

    config_.Set(param);
106
107
108
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
116
117
118
119
120

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    
wxchan's avatar
wxchan committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
141

142
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
143
    std::lock_guard<std::mutex> lock(mutex_);
144
145
146
147
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
148
    std::lock_guard<std::mutex> lock(mutex_);
149
150
151
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
152
153
154
155
156
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
157
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
158
159
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
160
161
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
162
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
163
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
164
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
165
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
166
167
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    }
Guolin Ke's avatar
Guolin Ke committed
169
170
171
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
172
173
  }

174
  void GetPredictAt(int data_idx, score_t* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
175
176
177
178
179
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
180
  }
181

wxchan's avatar
wxchan committed
182
183
184
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
185

Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
191
192
193
194
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
195
196
197
198
199
200
201
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
202

wxchan's avatar
wxchan committed
203
204
205
206
207
208
209
210
211
212
213
214
215
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
216
private:
217

wxchan's avatar
wxchan committed
218
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
219
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
220
221
222
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
223
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
224
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
225
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
226
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
227
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
228
229
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
230
231
232
};

}
Guolin Ke's avatar
Guolin Ke committed
233
234
235

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
236
/*!
* \brief Return the last error message.
* \return The string reported by LastErrorMsg()
*/
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

wxchan's avatar
wxchan committed
240
/*!
* \brief Load a dataset from a file.
* \param filename Path of the data file
* \param parameters Parameter string used to build the IOConfig
* \param reference If not null, the file is loaded aligned with this dataset
*        (LoadFromFileAlignWithOtherDataset)
* \param out Receives the newly allocated dataset
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr, filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
258
/*!
* \brief Build a dataset from a dense matrix.
* \param data Matrix values; element type given by data_type
* \param data_type Element type code forwarded to RowFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Storage-order flag forwarded to RowFunctionFromDenseMatric
* \param parameters Parameter string used to build the IOConfig
* \param reference If not null, its feature mappers are copied instead of
*        constructing new ones from a row sample
* \param out Receives the newly allocated dataset
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  std::unique_ptr<Dataset> dataset;
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // draw a random subset of rows and keep the non-zero values of each column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(
      io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto row_idx : sample_indices) {
      const auto row = get_row_fun(static_cast<int>(row_idx));
      for (size_t col = 0; col < row.size(); ++col) {
        const auto value = row[col];
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row; PushOneRow receives the id of the calling OpenMP thread
#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    auto row_values = get_row_fun(row_idx);
    dataset->PushOneRow(omp_get_thread_num(), row_idx, row_values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
307
/*!
* \brief Build a dataset from a CSR sparse matrix.
* \param indptr Row offsets (nindptr entries, one more than the number of rows)
* \param indptr_type Element type code of indptr
* \param indices Column index of each stored element
* \param data Stored values; element type given by data_type
* \param data_type Element type code of data
* \param nindptr Number of entries in indptr
* \param nelem Number of stored elements
* \param num_col Number of columns; must be at least one past the largest
*        column index seen in the sampled rows
* \param parameters Parameter string used to build the IOConfig
* \param reference If not null, its feature mappers are copied instead of
*        constructing new ones from a row sample
* \param out Receives the newly allocated dataset
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // indptr holds one offset per row plus a final end offset
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // one vector of sampled non-zero values per feature; it grows to cover the
    // largest feature index seen in the sampled rows
    std::vector<std::vector<double>> sample_values;
    for (const auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (const auto& inner_data : row) {
        const size_t feature_idx = static_cast<size_t>(inner_data.first);
        if (feature_idx >= sample_values.size()) {
          sample_values.resize(feature_idx + 1);
        }
        if (std::fabs(inner_data.second) > 1e-15) {
          sample_values[feature_idx].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

  // push every row; PushOneRow receives the id of the calling OpenMP thread
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
369
/*!
* \brief Build a dataset from a CSC sparse matrix.
* \param col_ptr Column offsets (ncol_ptr entries, one more than the number of columns)
* \param col_ptr_type Element type code of col_ptr
* \param indices Row index of each stored element
* \param data Stored values; element type given by data_type
* \param data_type Element type code of data
* \param ncol_ptr Number of entries in col_ptr
* \param nelem Number of stored elements
* \param num_row Number of rows
* \param parameters Parameter string used to build the IOConfig
* \param reference If not null, its feature mappers are copied instead of
*        constructing new ones from a row sample
* \param out Receives the newly allocated dataset
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);
  std::unique_ptr<Dataset> dataset;
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // pick random rows, then collect their values column by column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(
      io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto column = get_col_fun(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every column; PushOneColumn receives the id of the calling OpenMP thread
#pragma omp parallel for schedule(guided)
  for (int col = 0; col < ncol_ptr - 1; ++col) {
    auto column = get_col_fun(col);
    dataset->PushOneColumn(omp_get_thread_num(), col, column);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
419
/*!
* \brief Create a new dataset holding the selected rows of an existing one.
* \param handle Source dataset
* \param used_row_indices Indices of the rows to keep
* \param num_used_row_indices Number of entries in used_row_indices
* \param parameters Parameter string used to build the IOConfig
* \param out Receives the newly allocated subset
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(source->Subset(used_row_indices,
    num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
439
/*!
* \brief Replace the feature names of a dataset.
* \param handle Dataset to modify
* \param feature_names Array of NUL-terminated names
* \param num_feature_names Number of entries in feature_names
* \return Status code produced by API_END()
*/
DllExport int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  for (int64_t k = 0; k < num_feature_names; ++k) {
    names.push_back(std::string(feature_names[k]));
  }
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  dataset->set_feature_names(names);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
453
/*! \brief Delete a dataset created by one of the LGBM_DatasetCreate* / GetSubset calls. */
DllExport int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
459
/*! \brief Write a dataset to filename via Dataset::SaveBinaryFile. */
DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
467
/*!
* \brief Set a named field of a dataset from a float32 or int32 array.
* \param handle Dataset to modify
* \param field_name Name of the field
* \param field_data Array of num_element values of the given type
* \param num_element Number of values; must fit in int32_t
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \return Status code produced by API_END(); fails if the length is out of
*         range, the type is unsupported, or the field is not found
*/
DllExport int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // the count is narrowed to int32_t below; reject values that would be
  // silently truncated (or are negative) instead of corrupting the field
  if (num_element < 0 ||
      num_element > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
    throw std::runtime_error("Number of elements is out of range");
  }
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
484
/*!
* \brief Look up a named field of a dataset.
* \param handle Dataset to query
* \param field_name Name of the field
* \param out_len Receives the field length (0 if the field pointer is null)
* \param out_ptr Receives a pointer to the field data owned by the dataset
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \return Status code produced by API_END(); fails if the field is not found
*/
DllExport int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // float fields are tried first, then int fields
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
504
/*! \brief Store the number of rows (Dataset::num_data) of a dataset in *out. */
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
512
/*! \brief Store the total feature count (Dataset::num_total_features) in *out. */
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
519
520
521
522


// ---- start of booster

Guolin Ke's avatar
typo  
Guolin Ke committed
523
/*!
* \brief Create a booster for training.
* \param train_data Training dataset; not owned by the booster
* \param parameters Parameter string
* \param out Receives the new booster
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
533
/*!
* \brief Load a booster from a model file.
* \param filename Model file
* \param out_num_iterations Receives total models divided by number of classes
* \param out Receives the new booster
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* boosting = booster->GetBoosting();
  const auto num_models = boosting->NumberOfTotalModel();
  const auto num_classes = boosting->NumberOfClasses();
  *out_num_iterations = static_cast<int64_t>(num_models / num_classes);
  *out = booster.release();
  API_END();
}

/*! \brief Delete a booster created by LGBM_BoosterCreate / LGBM_BoosterCreateFromModelfile. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
551
552
553
554
555
556
557
558
559
560
/*! \brief Merge the model of other_handle into the booster at handle. */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

/*! \brief Register a validation dataset (not owned) with a booster. */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Switch a booster to a new training dataset (not owned). */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

/*! \brief Apply a new parameter string to a booster via Booster::ResetConfig. */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

/*! \brief Store the booster's number of classes in *out_len. */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

592
/*!
* \brief Train one iteration with the booster's own objective.
* \param is_finished Set to 1 if TrainOneIter reported completion, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Train one iteration with caller-supplied gradients and hessians.
* \param is_finished Set to 1 if TrainOneIter reported completion, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
/*! \brief Undo the last training iteration of a booster. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Store the booster's current iteration in *out_iteration. */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
/*!
* \brief Count the evaluation results produced by the training metrics.
* \param handle Booster handle
* \param out_len Receives the total number of metric names over all training metrics
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<const Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

/*!
* \brief Copy the names of the training metrics' evaluation results.
* \param handle Booster handle
* \param out_len Receives the number of names written
* \param out_strs Caller-allocated array with at least LGBM_BoosterGetEvalCounts()
*        entries; each name is copied with strcpy, so every buffer must be
*        large enough for its name plus the terminating NUL
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Fetch the evaluation results for one dataset of the booster.
* \param data_idx Dataset index forwarded to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller-allocated buffer large enough for all results
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(eval_values.size());
  float* dst = out_results;
  for (const auto& value : eval_values) {
    *dst++ = static_cast<float>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
668
/*! \brief Copy the booster's current predictions for dataset data_idx into out_result. */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

678
679
/*!
* \brief Predict every row of a data file and write the results to another file.
* \param data_has_header Any value > 0 means the data file has a header line
* \param predict_type Prediction kind passed to Booster::NewPredictor
* \param num_iteration Iteration limit passed to Booster::NewPredictor
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

692
/*!
* \brief Predict every row of a CSR sparse matrix.
* \param indptr Row offsets (nindptr entries, one more than the number of rows)
* \param indptr_type Element type code of indptr
* \param indices Column index of each stored element
* \param data Stored values; element type given by data_type
* \param data_type Element type code of data
* \param nindptr Number of entries in indptr
* \param nelem Number of stored elements
* \param (unnamed) Unused
* \param predict_type Prediction kind passed to Booster::NewPredictor
* \param num_iteration Iteration limit; <= 0 means all trees for leaf-index output
* \param out_len Receives nrow * (values per row)
* \param out_result Caller-allocated buffer of at least *out_len floats; row i
*        starts at offset i * (values per row)
* \return Status code produced by API_END()
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);

  // values per row: one per class, or one per tree (classes * iterations) for
  // leaf-index output. Kept in int64_t so the offsets below cannot overflow int.
  const int num_class = ref_booster->GetBoosting()->NumberOfClasses();
  int64_t num_pred_in_one_row = num_class;
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    const int64_t num_used_iterations = num_iteration > 0
      ? num_iteration
      : ref_booster->GetBoosting()->NumberOfTotalModel() / num_class;
    num_pred_in_one_row *= num_used_iterations;
  }

  const int nrow = static_cast<int>(nindptr - 1);
  // loop-invariant: fetch the prediction function once instead of per row
  const auto& predict_fun = predictor.GetPredictFunction();
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = predict_fun(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * num_pred_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_pred_in_one_row;
  API_END();
}
729
730
731

// Predicts every row of a dense nrow x ncol matrix into the caller-provided buffer.
// - data / data_type / is_row_major describe the matrix and are handed to
//   RowPairFunctionFromDenseMatric.
// - out_result must have room for nrow * K floats, where K is the number of
//   classes, multiplied by the iteration count when predict_type is
//   C_API_PREDICT_LEAF_INDEX.
// - *out_len receives the total number of floats written.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  // Values produced per row: one per class; for leaf-index output, one per class
  // per iteration (all iterations when num_iteration <= 0).
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  // Offsets are computed in 64 bits: nrow * num_preb_in_one_row can exceed INT_MAX,
  // and signed int overflow would be undefined behaviour.
  const int64_t row_stride = static_cast<int64_t>(num_preb_in_one_row);
  // NOTE(review): an exception thrown inside this OpenMP region cannot propagate
  // out of it to API_END; consider capturing it per-thread and rethrowing.
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = predictor.GetPredictFunction()(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * row_stride;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * row_stride;
  API_END();
}
763
764

// Saves the booster's model to `filename`, forwarding num_iteration unchanged
// to Booster::SaveModelToFile.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Writes the text returned by Booster::DumpModel into the caller's buffer.
// *out_len is always set to the dump length plus one (room for the trailing NUL).
// The text is copied into *out_str only when that size fits in buffer_len;
// otherwise the buffer is left untouched and the caller can size one from *out_len.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel();
  const int64_t required_len = static_cast<int64_t>(dumped.size()) + 1;
  *out_len = required_len;
  const bool fits = required_len <= buffer_len;
  if (fits) {
    std::strcpy(*out_str, dumped.c_str());
  }
  API_END();
}
786

Guolin Ke's avatar
Guolin Ke committed
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809


// Reads the value of leaf `leaf_idx` in tree `tree_idx` and stores it, narrowed
// to float, into *out_val.
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  float* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<float>(leaf_value);
  API_END();
}


// Overwrites the value of leaf `leaf_idx` in tree `tree_idx` with `val`,
// widened to double before being handed to the booster.
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  float val) {
  API_BEGIN();
  const double new_value = static_cast<double>(val);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, new_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
810
// ---- start of some help functions
811
812
813

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
814
  if (data_type == C_API_DTYPE_FLOAT32) {
815
816
817
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
818
        std::vector<double> ret(num_col);
819
820
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
821
          ret[i] = static_cast<double>(*(tmp_ptr + i));
822
823
824
825
826
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
827
        std::vector<double> ret(num_col);
828
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
829
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
830
831
832
833
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
834
  } else if (data_type == C_API_DTYPE_FLOAT64) {
835
836
837
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
838
        std::vector<double> ret(num_col);
839
840
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
841
          ret[i] = static_cast<double>(*(tmp_ptr + i));
842
843
844
845
846
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
847
        std::vector<double> ret(num_col);
848
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
849
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
850
851
852
853
854
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
855
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
856
857
858
859
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
860
861
862
863
864
865
866
867
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
868
        }
Guolin Ke's avatar
Guolin Ke committed
869
870
871
      }
      return ret;
    };
872
  }
Guolin Ke's avatar
Guolin Ke committed
873
  return nullptr;
874
875
876
877
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
878
  if (data_type == C_API_DTYPE_FLOAT32) {
879
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
880
    if (indptr_type == C_API_DTYPE_INT32) {
881
882
883
884
885
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
886
        for (int64_t i = start; i < end; ++i) {
887
888
889
890
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
891
    } else if (indptr_type == C_API_DTYPE_INT64) {
892
893
894
895
896
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
897
        for (int64_t i = start; i < end; ++i) {
898
899
900
901
902
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
903
  } else if (data_type == C_API_DTYPE_FLOAT64) {
904
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
905
    if (indptr_type == C_API_DTYPE_INT32) {
906
907
908
909
910
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
911
        for (int64_t i = start; i < end; ++i) {
912
913
914
915
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
916
    } else if (indptr_type == C_API_DTYPE_INT64) {
917
918
919
920
921
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
922
        for (int64_t i = start; i < end; ++i) {
923
924
925
926
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
927
928
929
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
930
931
932
933
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
934
  if (data_type == C_API_DTYPE_FLOAT32) {
935
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
936
    if (col_ptr_type == C_API_DTYPE_INT32) {
937
938
939
940
941
942
943
944
945
946
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
947
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
948
949
950
951
952
953
954
955
956
957
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
958
    } 
Guolin Ke's avatar
Guolin Ke committed
959
  } else if (data_type == C_API_DTYPE_FLOAT64) {
960
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
961
    if (col_ptr_type == C_API_DTYPE_INT32) {
962
963
964
965
966
967
968
969
970
971
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
972
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
973
974
975
976
977
978
979
980
981
982
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
983
984
985
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
986
987
}

988
// Picks the stored values of a sparse column at the requested row indices.
// `data` holds (row, value) pairs. A single forward cursor walks `data` across
// all requests, so both `data` and `indices` are expected to be sorted ascending
// by row. Rows with no stored entry contribute nothing to the result.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto stop = data.end();
  for (const int wanted_row : indices) {
    while (cursor != stop && cursor->first < wanted_row) {
      ++cursor;
    }
    const bool hit = (cursor != stop) && (cursor->first == wanted_row);
    if (hit) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}