"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "823fc03cc820758fceae91eb33851e1534b65360"
c_api.cpp 31.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
wxchan's avatar
wxchan committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
32
33
34
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
35
36
37
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
38
39
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
40
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
41
42
        please use continued train with input score");
    }
wxchan's avatar
wxchan committed
43
44
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
    ConstructObjectAndTrainingMetrics(train_data);
Guolin Ke's avatar
Guolin Ke committed
45
    // initialize the boosting
wxchan's avatar
wxchan committed
46
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
47
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
48
49
50
51
52
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
53
54
55
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
56

Guolin Ke's avatar
Guolin Ke committed
57
  }
58

wxchan's avatar
wxchan committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
    // initialize the boosting
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
78
79
80
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
wxchan's avatar
wxchan committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
      // only need to set learning rate
      boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
    } else {
      ResetTrainingData(train_data_);
    }
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
102
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
103
    std::lock_guard<std::mutex> lock(mutex_);
104
105
106
107
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
108
    std::lock_guard<std::mutex> lock(mutex_);
109
110
111
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
112
113
114
115
116
117
118
119
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
120
121
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
122
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
123
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
124
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
125
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
126
127
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
130
131
  }

wxchan's avatar
wxchan committed
132
133
134
135
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
136
137
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
138
139
  }

140
141
142
143
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
144
145
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
146
  }
147

wxchan's avatar
wxchan committed
148
149
150
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
151

wxchan's avatar
wxchan committed
152
153
154
155
156
157
158
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
159

wxchan's avatar
wxchan committed
160
161
162
163
164
165
166
167
168
169
170
171
172
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
173
private:
174

wxchan's avatar
wxchan committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
199
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
200
201
202
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
203
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
204
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
205
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
206
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
207
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
208
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
209
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
210
211
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
212
213
214
};

}
Guolin Ke's avatar
Guolin Ke committed
215
216
217

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
218
/*! \brief Return the error message provided by LastErrorMsg() */
DllExport const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

wxchan's avatar
wxchan committed
222
/*!
* \brief Load a dataset from a file
* \param filename Path of the data file
* \param parameters Parameter string used to fill an IOConfig
* \param reference Optional existing dataset to align with; nullptr to load stand-alone
* \param out Receives the handle of the loaded dataset
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr, filename);
  if (reference != nullptr) {
    // reuse the layout of the reference dataset
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

wxchan's avatar
wxchan committed
240
/*!
* \brief Build a dataset from a dense matrix
*        Without a reference dataset, rows are sampled first and the sampled
*        values (those with magnitude above 1e-15) are handed to the
*        DatasetLoader; with a reference, its feature mappers are copied.
*        Afterwards every row is pushed into the dataset in an OpenMP loop.
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto row_reader = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto sampled_idx : sample_indices) {
      auto sampled_row = row_reader(static_cast<int>(sampled_idx));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        // only keep values that are not (numerically) zero
        if (std::fabs(sampled_row[col]) > 1e-15) {
          sample_values[col].push_back(sampled_row[col]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto current_row = row_reader(row_idx);
    dataset->PushOneRow(tid, row_idx, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
289
/*!
* \brief Build a dataset from CSR formatted data
*        The number of rows is nindptr - 1. Without a reference dataset,
*        rows are sampled and their values with magnitude above 1e-15 are
*        grouped per feature; with a reference, its feature mappers are copied.
*        Afterwards every row is pushed into the dataset in an OpenMP loop.
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto sampled_idx : sample_indices) {
      auto sampled_row = row_reader(static_cast<int>(sampled_idx));
      for (const auto& entry : sampled_row) {
        if (std::fabs(entry.second) <= 1e-15) { continue; }
        const size_t feature_idx = static_cast<size_t>(entry.first);
        if (feature_idx >= sample_values.size()) {
          // grow the feature set so that feature_idx becomes addressable
          sample_values.resize(feature_idx + 1);
        }
        // record the feature value
        sample_values[entry.first].push_back(entry.second);
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nindptr - 1; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto current_row = row_reader(row_idx);
    dataset->PushOneRow(tid, row_idx, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
351
/*!
* \brief Build a dataset from CSC formatted data
*        The number of columns is ncol_ptr - 1. Without a reference dataset,
*        each column is sampled through SampleFromOneColumn; with a reference,
*        its feature mappers are copied. Afterwards every column is pushed
*        into the dataset in an OpenMP loop.
* \param out Receives the handle of the new dataset
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto column_reader = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int col_idx = 0; col_idx < static_cast<int>(sample_values.size()); ++col_idx) {
      auto sampled_col = column_reader(col_idx);
      sample_values[col_idx] = SampleFromOneColumn(sampled_col, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col_idx = 0; col_idx < ncol_ptr - 1; ++col_idx) {
    const int tid = omp_get_thread_num();
    auto current_col = column_reader(col_idx);
    dataset->PushOneColumn(tid, col_idx, current_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
/*!
* \brief Create a new dataset from selected rows of an existing one
* \param handle Pointer to the handle of the full dataset
* \param used_row_indices Row indices to take
* \param num_used_row_indices Length of used_row_indices
* \param parameters Parameter string used to fill an IOConfig
* \param out Receives the handle of the subset
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
421
422
423
424
425
426
427
428
429
430
431
432
433
434
/*!
* \brief Copy the given C strings into a vector and set them as feature names of the dataset
* \param handle Dataset handle
* \param feature_names Array of num_feature_names C strings
* \param num_feature_names Number of entries in feature_names
*/
DllExport int LGBM_DatasetSetFeatureNames(
  DatesetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  int64_t pos = 0;
  while (pos < num_feature_names) {
    names.push_back(std::string(feature_names[pos]));
    ++pos;
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

435
/*! \brief Delete the Dataset object behind the handle */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

/*!
* \brief Write the dataset to a binary file through Dataset::SaveBinaryFile
* \param handle Dataset handle
* \param filename Destination file name
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of the dataset from a raw buffer
* \param handle Dataset handle
* \param field_name Name of the field
* \param field_data Buffer holding num_element values of the given type
* \param num_element Number of elements (narrowed to int32_t before use)
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value is reported as an error
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  }
  // unsupported type and unknown field are both reported through the same error
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a named field of the dataset
*        The float fields are tried first, then the int fields.
* \param handle Dataset handle
* \param field_name Name of the field
* \param out_len Receives the field length; forced to 0 when the returned pointer is null
* \param out_ptr Receives a pointer to the field data
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
* \brief Report the value of Dataset::num_data()
* \param handle Dataset handle
* \param out Receives the number of data
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Report the value of Dataset::num_total_features()
* \param handle Dataset handle
* \param out Receives the number of features
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
501
502
503
504
505
506
507


// ---- start of booster

/*!
* \brief Create a booster for training on the given dataset
* \param train_data Handle of the training dataset
* \param parameters Parameter string passed to the Booster constructor
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  // hold the booster in a smart pointer until it is handed to the caller
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
515
/*!
* \brief Create a booster from a model file
* \param filename Model file name
* \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses() of the loaded model
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(
    model->NumberOfTotalModel() / model->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*! \brief Delete the Booster object behind the handle */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

wxchan's avatar
wxchan committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
/*!
* \brief Merge the model of other_handle into the booster of handle
* \param handle Booster that receives the merge
* \param other_handle Booster that is merged from
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(source);
  API_END();
}

/*!
* \brief Register a validation dataset with the booster
* \param handle Booster handle
* \param valid_data Handle of the validation dataset
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  const Dataset* validation_set = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(validation_set);
  API_END();
}

/*!
* \brief Replace the training data of the booster
* \param handle Booster handle
* \param train_data Handle of the new training dataset
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(training_set);
  API_END();
}

/*!
* \brief Forward a parameter string to Booster::ResetConfig
* \param handle Booster handle
* \param parameters Parameter string
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

/*!
* \brief Report NumberOfClasses() of the wrapped boosting object
* \param handle Booster handle
* \param out_len Receives the number of classes
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

574
/*!
* \brief Run one training iteration
* \param handle Booster handle
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, 0 otherwise
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* trainer = reinterpret_cast<Booster*>(handle);
  const bool finished = trainer->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller supplied gradients and hessians
* \param handle Booster handle
* \param grad Gradient buffer
* \param hess Hessian buffer
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, 0 otherwise
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* trainer = reinterpret_cast<Booster*>(handle);
  const bool finished = trainer->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
/*! \brief Undo the latest iteration of the booster behind the handle */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report GetCurrentIteration() of the wrapped boosting object
* \param handle Booster handle
* \param out_iteration Receives the current iteration
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get the number of eval results, as counted by Booster::GetEvalCounts
* \param handle Booster handle
* \param out_len Receives the total number of eval results
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const int eval_count = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

/*!
* \brief Get names of eval
* \return total number of eval result
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  // the booster copies each name into out_strs and reports how many it wrote
  const int names_written = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  *out_len = names_written;
  API_END();
}


/*!
* \brief Fetch the eval results of one dataset
* \param handle Booster handle
* \param data_idx Index passed to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller provided buffer that receives the results converted to float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = model->GetEvalAt(data_idx);
  const size_t num_values = eval_values.size();
  *out_len = static_cast<int64_t>(num_values);
  for (size_t pos = 0; pos < num_values; ++pos) {
    out_results[pos] = static_cast<float>(eval_values[pos]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
650
/*!
* \brief Fetch the predictions the booster holds for one dataset
* \param handle Booster handle
* \param data_idx Index passed to Booster::GetPredictAt
* \param out_len Receives the number of values written
* \param out_result Caller provided buffer that receives the predictions
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  int written = 0;
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

662
663
/*!
* \brief Predict every row of a data file and write the results to another file
* \param handle Booster handle
* \param data_filename Input data file
* \param data_has_header Treated as true when greater than 0
* \param predict_type Prediction type forwarded to Booster::PrepareForPrediction
* \param num_iteration Narrowed to int and forwarded to Booster::PrepareForPrediction
* \param result_filename Output file
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* predictor_owner = reinterpret_cast<Booster*>(handle);
  predictor_owner->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  predictor_owner->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

676
/*!
* \brief Predict for CSR formatted data
*        The number of rows is nindptr - 1. Each row produces
*        NumberOfClasses() values, multiplied by the number of iterations
*        when leaf indices are requested.
* \param handle Booster handle
* \param predict_type Prediction type forwarded to Booster::PrepareForPrediction
* \param num_iteration Number of iterations used for prediction; when not positive
*        and leaf indices are requested, all iterations of the model are counted
* \param out_len Receives the total number of values written
* \param out_result Caller provided buffer that receives the predictions, row after row
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // compute the row offset in 64 bit: i * num_preb_in_one_row can exceed the int range
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  // widen before multiplying so that the reported length cannot overflow int
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
714
715
716

/*!
 * \brief Predict for every row of a dense matrix.
 *
 * Row i is read through RowPairFunctionFromDenseMatric and its predictions
 * are written to out_result starting at offset i * num_preb_in_one_row.
 *
 * \param handle Opaque handle that is reinterpreted as a Booster*
 * \param data Matrix values; element type is selected by data_type
 * \param data_type Type tag of data, forwarded to RowPairFunctionFromDenseMatric
 * \param nrow Number of rows to predict
 * \param ncol Number of columns per row
 * \param is_row_major Non-zero when data is laid out row by row
 * \param predict_type Prediction type; C_API_PREDICT_LEAF_INDEX widens the per-row output
 * \param num_iteration Iteration limit; narrowed to int before use
 * \param out_len Receives nrow * num_preb_in_one_row
 * \param out_result Caller-provided output buffer (presumably at least *out_len floats; confirm with callers)
 * \return Status code produced by the API_BEGIN / API_END wrapper
 */
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  // Keep the row stride in 64 bits: row_index * stride and nrow * stride can
  // exceed the range of int for large inputs, which would overflow before the
  // result is used as a buffer offset or stored into the int64_t out_len.
  const int64_t row_stride = num_preb_in_one_row;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    float* out_row = out_result + i * row_stride;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_row[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = nrow * row_stride;
  API_END();
}
749
750

/*!
 * \brief Save the booster's model to a file.
 * \param handle Opaque handle that is reinterpreted as a Booster*
 * \param num_iteration Forwarded unchanged to Booster::SaveModelToFile
 * \param filename Destination path, forwarded to Booster::SaveModelToFile
 * \return Status code produced by the API_BEGIN / API_END wrapper
 */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

/*!
 * \brief Dump the model as a string into a caller-provided buffer.
 *
 * The required size (string length plus the terminating NUL) is always
 * reported through out_len; the text is copied only when buffer_len is
 * large enough to hold it.
 *
 * \param handle Opaque handle that is reinterpreted as a Booster*
 * \param buffer_len Capacity, in bytes, of the buffer that *out_str points to
 * \param out_len Receives the number of bytes needed, including the NUL
 * \param out_str Points to the destination buffer
 * \return Status code produced by the API_BEGIN / API_END wrapper
 */
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel();
  const int64_t required_len = static_cast<int64_t>(dumped.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(*out_str, dumped.c_str());
  }
  API_END();
}
772

Guolin Ke's avatar
Guolin Ke committed
773
// ---- start of some helper functions
774
775
776

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
777
  if (data_type == C_API_DTYPE_FLOAT32) {
778
779
780
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
781
        std::vector<double> ret(num_col);
782
783
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
784
          ret[i] = static_cast<double>(*(tmp_ptr + i));
785
786
787
788
789
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
790
        std::vector<double> ret(num_col);
791
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
792
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
793
794
795
796
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
797
  } else if (data_type == C_API_DTYPE_FLOAT64) {
798
799
800
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
801
        std::vector<double> ret(num_col);
802
803
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
804
          ret[i] = static_cast<double>(*(tmp_ptr + i));
805
806
807
808
809
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
810
        std::vector<double> ret(num_col);
811
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
812
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
813
814
815
816
817
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
818
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
819
820
821
822
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
823
824
825
826
827
828
829
830
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
831
        }
Guolin Ke's avatar
Guolin Ke committed
832
833
834
      }
      return ret;
    };
835
  }
Guolin Ke's avatar
Guolin Ke committed
836
  return nullptr;
837
838
839
840
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
841
  if (data_type == C_API_DTYPE_FLOAT32) {
842
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
843
    if (indptr_type == C_API_DTYPE_INT32) {
844
845
846
847
848
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
849
        for (int64_t i = start; i < end; ++i) {
850
851
852
853
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
854
    } else if (indptr_type == C_API_DTYPE_INT64) {
855
856
857
858
859
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
860
        for (int64_t i = start; i < end; ++i) {
861
862
863
864
865
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
866
  } else if (data_type == C_API_DTYPE_FLOAT64) {
867
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
868
    if (indptr_type == C_API_DTYPE_INT32) {
869
870
871
872
873
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
874
        for (int64_t i = start; i < end; ++i) {
875
876
877
878
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
879
    } else if (indptr_type == C_API_DTYPE_INT64) {
880
881
882
883
884
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
885
        for (int64_t i = start; i < end; ++i) {
886
887
888
889
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
890
891
892
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
893
894
895
896
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
897
  if (data_type == C_API_DTYPE_FLOAT32) {
898
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
899
    if (col_ptr_type == C_API_DTYPE_INT32) {
900
901
902
903
904
905
906
907
908
909
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
910
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
911
912
913
914
915
916
917
918
919
920
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
921
    } 
Guolin Ke's avatar
Guolin Ke committed
922
  } else if (data_type == C_API_DTYPE_FLOAT64) {
923
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
924
    if (col_ptr_type == C_API_DTYPE_INT32) {
925
926
927
928
929
930
931
932
933
934
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
935
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
936
937
938
939
940
941
942
943
944
945
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
946
947
948
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
949
950
}

951
/*!
 * \brief Pick the values of a sparse column at the requested row indices.
 *
 * Walks data and indices together with a single forward cursor, so a match
 * is only found for entries at or after the cursor's current position.
 * Requested rows that have no entry in data contribute nothing to the result.
 *
 * \param data (row index, value) pairs of one column
 * \param indices Row indices to sample
 * \return Values of the matched entries, in the order they were matched
 */
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto data_end = data.end();
  for (const int wanted_row : indices) {
    // skip entries that lie before the requested row
    for (; cursor != data_end && cursor->first < wanted_row; ++cursor) {
    }
    if (cursor != data_end && cursor->first == wanted_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}