c_api.cpp 32 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
wxchan's avatar
wxchan committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
32
33
34
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
35
36
37
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
38
39
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
40
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
41
42
        please use continued train with input score");
    }
wxchan's avatar
wxchan committed
43
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
44
    train_data_ = train_data;
wxchan's avatar
wxchan committed
45
    ConstructObjectAndTrainingMetrics(train_data);
Guolin Ke's avatar
Guolin Ke committed
46
    // initialize the boosting
wxchan's avatar
wxchan committed
47
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
48
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
49
50
51
52
53
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
54
55
56
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
57

Guolin Ke's avatar
Guolin Ke committed
58
  }
59

wxchan's avatar
wxchan committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
    // initialize the boosting
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
77
78
79
80
81
82
83
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
    {
      std::lock_guard<std::mutex> lock(mutex_);
      config_.Set(param);
    }
84
    if (config_.num_threads > 0) {
Guolin Ke's avatar
Guolin Ke committed
85
      std::lock_guard<std::mutex> lock(mutex_);
86
87
      omp_set_num_threads(config_.num_threads);
    }
wxchan's avatar
wxchan committed
88
89
    if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
      // only need to set learning rate
Guolin Ke's avatar
Guolin Ke committed
90
      std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
      boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
    } else {
      ResetTrainingData(train_data_);
    }
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
110
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
111
    std::lock_guard<std::mutex> lock(mutex_);
112
113
114
115
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
116
    std::lock_guard<std::mutex> lock(mutex_);
117
118
119
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
120
121
122
123
124
125
126
127
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
128
129
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
130
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
131
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
132
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
133
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
134
135
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
136
    }
Guolin Ke's avatar
Guolin Ke committed
137
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
138
139
  }

wxchan's avatar
wxchan committed
140
141
142
143
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
144
145
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
146
147
  }

148
149
150
151
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
152
153
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
154
  }
155

wxchan's avatar
wxchan committed
156
157
158
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
159

wxchan's avatar
wxchan committed
160
161
162
163
164
165
166
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
167

wxchan's avatar
wxchan committed
168
169
170
171
172
173
174
175
176
177
178
179
180
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
181
private:
182

wxchan's avatar
wxchan committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
207
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
208
209
210
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
211
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
212
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
213
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
214
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
215
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
216
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
217
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
218
219
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
220
221
222
};

}
Guolin Ke's avatar
Guolin Ke committed
223
224
225

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
226
/*! \brief Return the message recorded for the most recent failed C API call. */
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

wxchan's avatar
wxchan committed
230
/*!
* \brief Load a Dataset from a text/binary file.
*        With a reference handle, loading is aligned with that existing dataset.
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr, filename);
  if (reference != nullptr) {
    const Dataset* reference_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, reference_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
248
/*!
* \brief Build a Dataset from a dense nrow x ncol matrix.
*        Without a reference, feature bin mappers are built from a random sample
*        of rows (only values with |v| > 1e-15 are sampled); with a reference,
*        its feature mappers are copied. All rows are then pushed in parallel.
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(io_config.bin_construct_sample_cnt < nrow
      ? io_config.bin_construct_sample_cnt : nrow);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto row_idx : sample_indices) {
      const auto row = get_row_fun(static_cast<int>(row_idx));
      for (size_t col = 0; col < row.size(); ++col) {
        const double value = row[col];
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
297
/*!
* \brief Build a Dataset from a CSR (compressed sparse row) matrix.
* \param indptr Row-pointer array with nindptr entries (so nindptr - 1 rows); element type given by indptr_type
* \param indices Column index of each stored element
* \param data Stored values; element type given by data_type
* \param num_col Number of matrix columns; when no reference is given, the sampled
*        feature set is padded to exactly this many columns
* \param reference Optional dataset whose feature mappers are copied instead of sampling
* \param out Receives the new Dataset handle
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample rows first and collect the non-zero values of each column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto row = get_row_fun(static_cast<int>(sample_indices[i]));
      for (const auto& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          const size_t col = static_cast<size_t>(inner_data.first);
          if (col >= sample_values.size()) {
            // grow the feature set to cover this column
            sample_values.resize(col + 1);
          }
          sample_values[col].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    // Columns right of the last non-zero seen in the sample would otherwise be
    // missing, leaving fewer than num_col features while PushOneRow below still
    // receives their indices. Pad with empty columns, as the dense path already
    // has for all-zero columns.
    sample_values.resize(static_cast<size_t>(num_col));
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
359
/*!
* \brief Build a Dataset from a CSC (compressed sparse column) matrix with num_row rows.
*        Without a reference, bin mappers are built by sampling the same row
*        indices out of every column (logged as inefficient). Columns are then
*        pushed into the dataset in parallel.
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(io_config.bin_construct_sample_cnt < nrow
      ? io_config.bin_construct_sample_cnt : nrow);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto column = get_col_fun(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col = 0; col < ncol_ptr - 1; ++col) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(col);
    dataset->PushOneColumn(tid, col, one_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
/*! \brief Create a new Dataset containing only the given rows of *handle. */
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(source->Subset(used_row_indices,
    num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
429
430
431
432
433
434
435
436
437
438
439
440
441
442
/*! \brief Copy num_feature_names C strings into the dataset as its feature names. */
DllExport int LGBM_DatasetSetFeatureNames(
  DatesetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  int64_t pos = 0;
  while (pos < num_feature_names) {
    names.emplace_back(feature_names[pos]);
    ++pos;
  }
  dataset->set_feature_names(names);
  API_END();
}

443
/*! \brief Destroy a Dataset previously returned by one of the create functions. */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*! \brief Write the dataset to filename via Dataset::SaveBinaryFile. */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named metadata field of the dataset.
* \param field_data Pointer to num_element values of the type given by `type`
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value fails
* \note num_element is narrowed to int32_t before being passed to the dataset.
*       Throws std::runtime_error when the type is unsupported or the dataset
*       rejects the field.
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a named metadata field, trying the float fields first, then the int fields.
*        *out_type reports which kind matched; *out_len is forced to 0 when the
*        returned pointer is null. Throws std::runtime_error if no field matches.
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*! \brief Report the number of rows in the dataset. */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
  API_END();
}

/*! \brief Report the dataset's total feature count (num_total_features). */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
  API_END();
}
509
510
511
512
513
514
515


// ---- start of booster

/*! \brief Create a training booster bound to train_data and configured by parameters. */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
523
/*!
* \brief Load a booster from a model file.
* \param out_num_iterations Receives total models / number of classes
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* boosting = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(boosting->NumberOfTotalModel()
    / boosting->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*! \brief Destroy a booster created by LGBM_BoosterCreate*. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
/*! \brief Merge the model of other_handle into the booster at handle (Booster::MergeFrom). */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  const Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

/*! \brief Register valid_data as an additional validation set of the booster. */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Replace the booster's training data with train_data. */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

/*! \brief Apply a new parameter string to the booster (Booster::ResetConfig). */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

/*! \brief Report the number of classes of the booster's model. */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

582
/*! \brief Run one training iteration; *is_finished is 1 when TrainOneIter() returns true, else 0. */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller-provided gradients and hessians.
*        *is_finished is 1 when TrainOneIter() returns true, else 0.
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
/*! \brief Undo the booster's most recent training iteration. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Report the boosting model's current iteration. */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get how many evaluation results are produced per data set
* \param out_len Receives the total number of metric names over all training metrics
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const int count = booster->GetEvalCounts();
  *out_len = count;
  API_END();
}

/*!
* \brief Get the names of all evaluation results (training-metric names)
* \param out_len Receives the number of names written
* \param out_strs Caller-allocated array of at least LGBM_BoosterGetEvalCounts()
*        char buffers, each large enough for its name; names are copied with strcpy
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results for data set data_idx into out_results as floats.
* \param out_len Receives the number of results written
* \param out_results Caller-allocated buffer large enough for all results
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto eval_values = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(eval_values.size());
  float* dst = out_results;
  for (const auto value : eval_values) {
    *dst = static_cast<float>(value);
    ++dst;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
658
/*!
* \brief Copy the booster's current predictions for data set data_idx into out_result.
* \param out_len Receives the number of values written
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  int written = 0;
  booster->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

670
671
/*!
* \brief Predict every row of data_filename and write the results to result_filename.
* \param data_has_header Any value > 0 means the first line is a header
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

684
/*!
* \brief Predict every row of a CSR matrix, writing results row-major into out_result.
*        Each row occupies num_class slots, or, for C_API_PREDICT_LEAF_INDEX,
*        num_class * (number of iterations used) slots.
* \param num_iteration Iterations to use; <= 0 means all. For leaf-index
*        output it is clamped to the iterations actually in the model, so the
*        row stride matches what the predictor produces.
* \param out_len Receives nrow * slots-per-row
* \param out_result Caller-allocated buffer of at least *out_len floats
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const Boosting* boosting = ref_booster->GetBoosting();
  const int num_class = boosting->NumberOfClasses();
  int num_preb_in_one_row = num_class;
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    // one leaf index per tree, num_class trees per iteration
    const int total_iteration = boosting->NumberOfTotalModel() / num_class;
    const int used_iteration = (num_iteration > 0 && num_iteration < total_iteration)
      ? static_cast<int>(num_iteration) : total_iteration;
    num_preb_in_one_row *= used_iteration;
  }
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: nrow * slots-per-row can exceed INT_MAX
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(prediction_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
722
723
724

// Predicts every row of a dense nrow x ncol matrix and writes the results row
// by row into out_result. Each row occupies num_preb_in_one_row consecutive floats.
// *out_len receives the total number of values written (nrow * per-row count).
// out_result must be large enough to hold that many floats.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  // Output width per row: NumberOfClasses() values normally; for leaf-index
  // prediction, multiplied by num_iteration (or by total models / classes
  // when num_iteration <= 0).
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }

  // NOTE(review): an exception thrown inside this OpenMP parallel region cannot
  // propagate out of it (OpenMP structured blocks must exit normally), so it will
  // not reach the handler API_END presumably installs. Consider capturing it with
  // std::exception_ptr and rethrowing after the loop.
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: i * num_preb_in_one_row can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(prediction_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}

// Saves the booster behind `handle` to `filename`, forwarding num_iteration
// unchanged to Booster::SaveModelToFile.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Dumps the booster behind `handle` as a string.
// *out_len always receives the required buffer size, including the
// terminating '\0'. The text is copied into *out_str only when that size fits
// within buffer_len; otherwise *out_str is left untouched, so callers can
// retry with a larger buffer.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  std::string model = ref_booster->DumpModel();
  // +1 for the terminating '\0' written by strcpy.
  *out_len = static_cast<int64_t>(model.size()) + 1;
  if (*out_len <= buffer_len) {
    std::strcpy(*out_str, model.c_str());
  }
  API_END();
}

// ---- start of some help functions
782
783
784

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
785
  if (data_type == C_API_DTYPE_FLOAT32) {
786
787
788
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
789
        std::vector<double> ret(num_col);
790
791
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
792
          ret[i] = static_cast<double>(*(tmp_ptr + i));
793
794
795
796
797
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
798
        std::vector<double> ret(num_col);
799
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
800
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
801
802
803
804
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
805
  } else if (data_type == C_API_DTYPE_FLOAT64) {
806
807
808
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
809
        std::vector<double> ret(num_col);
810
811
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
812
          ret[i] = static_cast<double>(*(tmp_ptr + i));
813
814
815
816
817
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
818
        std::vector<double> ret(num_col);
819
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
820
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
821
822
823
824
825
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
826
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
827
828
829
830
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
831
832
833
834
835
836
837
838
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
839
        }
Guolin Ke's avatar
Guolin Ke committed
840
841
842
      }
      return ret;
    };
843
  }
Guolin Ke's avatar
Guolin Ke committed
844
  return nullptr;
845
846
847
848
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
849
  if (data_type == C_API_DTYPE_FLOAT32) {
850
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
851
    if (indptr_type == C_API_DTYPE_INT32) {
852
853
854
855
856
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
857
        for (int64_t i = start; i < end; ++i) {
858
859
860
861
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
862
    } else if (indptr_type == C_API_DTYPE_INT64) {
863
864
865
866
867
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
868
        for (int64_t i = start; i < end; ++i) {
869
870
871
872
873
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
874
  } else if (data_type == C_API_DTYPE_FLOAT64) {
875
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
876
    if (indptr_type == C_API_DTYPE_INT32) {
877
878
879
880
881
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
882
        for (int64_t i = start; i < end; ++i) {
883
884
885
886
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
887
    } else if (indptr_type == C_API_DTYPE_INT64) {
888
889
890
891
892
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
893
        for (int64_t i = start; i < end; ++i) {
894
895
896
897
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
898
899
900
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
901
902
903
904
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
905
  if (data_type == C_API_DTYPE_FLOAT32) {
906
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
907
    if (col_ptr_type == C_API_DTYPE_INT32) {
908
909
910
911
912
913
914
915
916
917
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
918
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
919
920
921
922
923
924
925
926
927
928
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
929
    } 
Guolin Ke's avatar
Guolin Ke committed
930
  } else if (data_type == C_API_DTYPE_FLOAT64) {
931
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
932
    if (col_ptr_type == C_API_DTYPE_INT32) {
933
934
935
936
937
938
939
940
941
942
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
943
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
944
945
946
947
948
949
950
951
952
953
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
954
955
956
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
957
958
}

959
// Returns, in order, the values of `data` whose row index appears in `indices`.
// Rows requested in `indices` but absent from `data` are skipped, not zero-filled.
// A single cursor only moves forward, so matches are found only when `data`
// is sorted by row index and `indices` is in ascending order. A repeated
// index yields the same value again.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  size_t j = 0;
  std::vector<double> ret;
  for (int row_idx : indices) {
    // Skip stored entries whose row precedes the requested one.
    while (j < data.size() && data[j].first < row_idx) {
      ++j;
    }
    if (j < data.size() && data[j].first == row_idx) {
      ret.push_back(data[j].second);
    }
  }
  return ret;
}