c_api.cpp 31.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
wxchan's avatar
wxchan committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
32
33
34
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
wxchan's avatar
wxchan committed
40
41
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
    ConstructObjectAndTrainingMetrics(train_data);
Guolin Ke's avatar
Guolin Ke committed
42
    // initialize the boosting
wxchan's avatar
wxchan committed
43
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
44
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
45
46
47
48
49
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
50
51
52
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
53

Guolin Ke's avatar
Guolin Ke committed
54
  }
55

wxchan's avatar
wxchan committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
    // initialize the boosting
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
      // only need to set learning rate
      boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
    } else {
      ResetTrainingData(train_data_);
    }
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
96
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
97
    std::lock_guard<std::mutex> lock(mutex_);
98
99
100
101
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
102
    std::lock_guard<std::mutex> lock(mutex_);
103
104
105
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
106
107
108
109
110
111
112
113
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
114
115
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
116
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
117
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
118
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
119
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
120
121
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
122
    }
Guolin Ke's avatar
Guolin Ke committed
123
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
124
125
  }

wxchan's avatar
wxchan committed
126
127
128
129
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
130
131
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
132
133
  }

134
135
136
137
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
138
139
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
140
  }
141

wxchan's avatar
wxchan committed
142
143
144
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
145

wxchan's avatar
wxchan committed
146
147
148
149
150
151
152
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
153

wxchan's avatar
wxchan committed
154
155
156
157
158
159
160
161
162
163
164
165
166
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
167
private:
168

wxchan's avatar
wxchan committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
193
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
194
195
196
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
197
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
198
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
199
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
200
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
201
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
202
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
203
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
204
205
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
206
207
208
};

}
Guolin Ke's avatar
Guolin Ke committed
209
210
211

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
212
/*!
* \brief Return the message provided by LastErrorMsg()
*/
DllExport const char* LGBM_GetLastError() {
  const char* const last_message = LastErrorMsg();
  return last_message;
}

wxchan's avatar
wxchan committed
216
/*!
* \brief Load a dataset from a file
* \param filename Data file; also used to set the loader header
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional dataset to align with; nullptr loads standalone
* \param out Receives the dataset pointer returned by the loader
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  DatasetLoader file_loader(io_config, nullptr);
  file_loader.SetHeader(filename);
  const bool has_reference = (reference != nullptr);
  if (has_reference) {
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = file_loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = file_loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
235
/*!
* \brief Build a dataset from a dense matrix
*
* Without a reference, rows are sampled and their values whose magnitude
* exceeds 1e-15 are used to construct the dataset; with a reference, the
* feature mappers are copied from it. All rows are then pushed in parallel.
* \param out Receives the newly created dataset
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_row = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows and collect their non-zero values per column
    Random rng(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rng.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto sampled_idx : sample_indices) {
      auto sampled_row = fetch_row(static_cast<int>(sampled_idx));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        if (std::fabs(sampled_row[col]) > 1e-15) {
          sample_values[col].push_back(sampled_row[col]);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = fetch_row(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
284
/*!
* \brief Build a dataset from CSR sparse input
*
* The row count is nindptr - 1. Without a reference, rows are sampled and
* their values whose magnitude exceeds 1e-15 are grouped by feature index;
* with a reference, the feature mappers are copied from it.
* \param num_col Checked to be at least the number of sampled feature columns
* \param out Receives the newly created dataset
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows and collect their non-zero values per feature
    Random rng(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rng.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto sampled_idx : sample_indices) {
      auto sampled_row = fetch_row(static_cast<int>(sampled_idx));
      for (auto& entry : sampled_row) {
        if (!(std::fabs(entry.second) > 1e-15)) { continue; }
        if (static_cast<size_t>(entry.first) >= sample_values.size()) {
          // grow the per-feature lists so that entry.first is a valid index
          size_t missing = entry.first - sample_values.size() + 1;
          for (size_t added = 0; added < missing; ++added) {
            sample_values.emplace_back();
          }
        }
        sample_values[entry.first].push_back(entry.second);
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nindptr - 1; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = fetch_row(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
346
/*!
* \brief Build a dataset from CSC sparse input
*
* There are ncol_ptr - 1 columns. Without a reference, each column is sampled
* at the chosen row indices (a warning is logged first); with a reference, the
* feature mappers are copied from it. Columns are then pushed in parallel.
* \param num_row Number of rows of the resulting dataset
* \param out Receives the newly created dataset
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_col = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // pick the sampled rows, then sample every column at those rows
    Random rng(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rng.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto cur_col = fetch_col(col);
      sample_values[col] = SampleFromOneColumn(cur_col, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col = 0; col < ncol_ptr - 1; ++col) {
    const int tid = omp_get_thread_num();
    auto one_col = fetch_col(col);
    dataset->PushOneColumn(tid, col, one_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
/*!
* \brief Create a new dataset from selected rows of an existing one
* \param handle Pointer to the handle of the source dataset
* \param used_row_indices Row indices passed to Dataset::Subset
* \param num_used_row_indices Number of entries in used_row_indices
* \param parameters Parameter string; is_enable_sparse is read from it
* \param out Receives the subset dataset
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

416
/*!
* \brief Delete the Dataset behind the handle
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*!
* \brief Forward to Dataset::SaveBinaryFile for the dataset behind the handle
* \param filename Passed through to SaveBinaryFile
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a float or int field of a dataset
* \param field_name Name passed to SetFloatField / SetIntField
* \param field_data Array interpreted according to type
* \param num_element Element count, narrowed to int32_t before use
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; anything else,
*        or a setter reporting failure, raises std::runtime_error
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  bool stored = false;
  if (type == C_API_DTYPE_INT32) {
    const int* int_values = reinterpret_cast<const int*>(field_data);
    stored = target->SetIntField(field_name, int_values, static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT32) {
    const float* float_values = reinterpret_cast<const float*>(field_data);
    stored = target->SetFloatField(field_name, float_values, static_cast<int32_t>(num_element));
  }
  if (!stored) { throw std::runtime_error("Input data type erorr or field not found"); }
  API_END();
}

/*!
* \brief Look up a field of a dataset, trying float fields first, then int
* \param out_len Receives the length; forced to 0 when *out_ptr is nullptr
* \param out_ptr Receives a pointer to the field data
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
*
* Raises std::runtime_error when neither lookup succeeds.
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*!
* \brief Write Dataset::num_data() of the dataset behind the handle to out
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Write Dataset::num_total_features() of the dataset behind the handle to out
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
482
483
484
485
486
487
488


// ---- start of booster

/*!
* \brief Create a Booster for the given training data
* \param train_data Handle of the training dataset
* \param parameters Parameter string passed to the Booster constructor
* \param out Receives the new booster
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
496
/*!
* \brief Create a Booster from a model file
* \param filename Model file passed to the Booster constructor
* \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses()
* \param out Receives the new booster
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(
    model->NumberOfTotalModel() / model->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*!
* \brief Delete the Booster behind the handle
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
/*!
* \brief Merge the booster behind other_handle into the one behind handle
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(source);
  API_END();
}

/*!
* \brief Register a validation dataset with the booster
* \param valid_data Handle of the validation dataset
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  const Dataset* validation_set = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(validation_set);
  API_END();
}

/*!
* \brief Replace the training data of the booster
* \param train_data Handle of the new training dataset
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(training_set);
  API_END();
}

/*!
* \brief Forward a parameter string to Booster::ResetConfig
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

/*!
* \brief Write Boosting::NumberOfClasses() of the booster to out_len
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

555
/*!
* \brief Run one training iteration
* \param is_finished Set to 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller-supplied gradients and hessians
* \param grad Gradients passed to Booster::TrainOneIter
* \param hess Hessians passed to Booster::TrainOneIter
* \param is_finished Set to 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
/*!
* \brief Forward to Booster::RollbackOneIter
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Write Boosting::GetCurrentIteration() of the booster to out_iteration
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get the number of evaluation results
* \param out_len Receives the value of Booster::GetEvalCounts
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const int eval_count = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

/*!
* \brief Get the names of the evaluation results
* \param out_len Receives the number of names written
* \param out_strs Caller-provided buffers filled by Booster::GetEvalNames
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  const int names_written = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  *out_len = names_written;
  API_END();
}


/*!
* \brief Get the evaluation results for one dataset
* \param data_idx Index passed to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller-provided array that receives the results as float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = model->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(eval_values.size());
  size_t slot = 0;
  for (const auto& value : eval_values) {
    out_results[slot] = static_cast<float>(value);
    ++slot;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
631
/*!
* \brief Get the predictions stored for one dataset
* \param data_idx Index passed to Booster::GetPredictAt
* \param out_len Receives the length reported by GetPredictAt
* \param out_result Caller-provided array filled by GetPredictAt
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  int num_written = 0;
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, &num_written);
  *out_len = static_cast<int64_t>(num_written);
  API_END();
}

643
644
/*!
* \brief Predict for a data file and write the result to another file
* \param data_has_header Treated as true when greater than 0
* \param predict_type One of the C_API_PREDICT_* constants
* \param num_iteration Narrowed to int and passed to PrepareForPrediction
* \param result_filename Output file passed to Booster::PredictForFile
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

657
/*!
* \brief Predict for rows given in CSR format
*
* Each row produces NumberOfClasses() values; for leaf-index prediction this
* is multiplied by num_iteration when positive, otherwise by the number of
* models per class.
* \param nindptr Length of indptr; the row count is nindptr - 1
* \param predict_type One of the C_API_PREDICT_* constants
* \param num_iteration Narrowed to int and passed to PrepareForPrediction
* \param out_len Receives rows * values-per-row
* \param out_result Caller-provided array, written row by row
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // widen before multiplying: rows * values-per-row can exceed 32 bits
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
695
696
697

/*!
* \brief Predict for rows given as a dense matrix
*
* Each row produces NumberOfClasses() values; for leaf-index prediction this
* is multiplied by num_iteration when positive, otherwise by the number of
* models per class.
* \param predict_type One of the C_API_PREDICT_* constants
* \param num_iteration Narrowed to int and passed to PrepareForPrediction
* \param out_len Receives nrow * values-per-row
* \param out_result Caller-provided array, written row by row
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // widen before multiplying: rows * values-per-row can exceed 32 bits
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
730
731

// C entry point that forwards num_iteration and filename to
// Booster::SaveModelToFile on the booster behind `handle`.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Copies the string returned by Booster::DumpModel into a caller-provided buffer.
//
// handle     : booster handle; reinterpreted as a Booster*
// buffer_len : capacity, in bytes, of the buffer that *out_str points to
// out_len    : receives the number of bytes needed to hold the dump INCLUDING
//              the terminating '\0'
// out_str    : points to the destination buffer; it is written only when
//              *out_len <= buffer_len, otherwise it is left untouched so the
//              caller can allocate *out_len bytes and call again
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const std::string model = ref_booster->DumpModel();
  // Count the terminator: comparing only model.size() against buffer_len let a
  // dump of exactly buffer_len characters write one byte past the buffer.
  *out_len = static_cast<int64_t>(model.size()) + 1;
  if (*out_len <= buffer_len) {
    std::memcpy(*out_str, model.c_str(), static_cast<size_t>(*out_len));
  }
  API_END();
}
753

Guolin Ke's avatar
Guolin Ke committed
754
// ---- start of some helper functions
755
756
757

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
758
  if (data_type == C_API_DTYPE_FLOAT32) {
759
760
761
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
762
        std::vector<double> ret(num_col);
763
764
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
765
          ret[i] = static_cast<double>(*(tmp_ptr + i));
766
767
768
769
770
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
771
        std::vector<double> ret(num_col);
772
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
773
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
774
775
776
777
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
778
  } else if (data_type == C_API_DTYPE_FLOAT64) {
779
780
781
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
782
        std::vector<double> ret(num_col);
783
784
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
785
          ret[i] = static_cast<double>(*(tmp_ptr + i));
786
787
788
789
790
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
791
        std::vector<double> ret(num_col);
792
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
793
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
794
795
796
797
798
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
799
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
800
801
802
803
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
804
805
806
807
808
809
810
811
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
812
        }
Guolin Ke's avatar
Guolin Ke committed
813
814
815
      }
      return ret;
    };
816
  }
Guolin Ke's avatar
Guolin Ke committed
817
  return nullptr;
818
819
820
821
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
822
  if (data_type == C_API_DTYPE_FLOAT32) {
823
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
824
    if (indptr_type == C_API_DTYPE_INT32) {
825
826
827
828
829
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
830
        for (int64_t i = start; i < end; ++i) {
831
832
833
834
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
835
    } else if (indptr_type == C_API_DTYPE_INT64) {
836
837
838
839
840
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
841
        for (int64_t i = start; i < end; ++i) {
842
843
844
845
846
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
847
  } else if (data_type == C_API_DTYPE_FLOAT64) {
848
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
849
    if (indptr_type == C_API_DTYPE_INT32) {
850
851
852
853
854
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
855
        for (int64_t i = start; i < end; ++i) {
856
857
858
859
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
860
    } else if (indptr_type == C_API_DTYPE_INT64) {
861
862
863
864
865
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
866
        for (int64_t i = start; i < end; ++i) {
867
868
869
870
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
871
872
873
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
874
875
876
877
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
878
  if (data_type == C_API_DTYPE_FLOAT32) {
879
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
880
    if (col_ptr_type == C_API_DTYPE_INT32) {
881
882
883
884
885
886
887
888
889
890
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
891
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
892
893
894
895
896
897
898
899
900
901
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
902
    } 
Guolin Ke's avatar
Guolin Ke committed
903
  } else if (data_type == C_API_DTYPE_FLOAT64) {
904
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
905
    if (col_ptr_type == C_API_DTYPE_INT32) {
906
907
908
909
910
911
912
913
914
915
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
916
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
917
918
919
920
921
922
923
924
925
926
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
927
928
929
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
930
931
}

932
// Collects the values of `data` (pairs of row index and value) whose row index
// appears in `indices`, in the order the indices are visited.
// A single forward-only cursor walks `data`, so matches are found only while
// both inputs are in ascending row order; requested rows with no entry in
// `data` contribute nothing to the result.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto data_end = data.end();
  for (const int wanted_row : indices) {
    // skip entries that belong to rows before the requested one
    for (; cursor != data_end && cursor->first < wanted_row; ++cursor) {
    }
    if (cursor == data_end || cursor->first != wanted_row) {
      continue;
    }
    sampled.push_back(cursor->second);
  }
  return sampled;
}