"docs/_static/js/script.js" did not exist on "12257feb9bb88e3bb4e9f734ebe4253f2cd4c5e2"
c_api.cpp 32.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
wxchan's avatar
wxchan committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
22
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
23

Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
28
29
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
30
31
32
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
33
34
35
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
36
37
38
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
39
40
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
41
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
42
43
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
44

wxchan's avatar
wxchan committed
45
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
46

Guolin Ke's avatar
Guolin Ke committed
47
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
48
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
49
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
50
51

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
52
53
54
55
56
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
57
58
59
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
60

Guolin Ke's avatar
Guolin Ke committed
61
  }
62

wxchan's avatar
wxchan committed
63
64
65
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
wxchan's avatar
wxchan committed
88
89
90
91
92
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
94
95
96
97
98
99
100
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
101
102
103
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
104
105

    config_.Set(param);
106
107
108
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
116
117
118
119
120

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    
wxchan's avatar
wxchan committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
141

142
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
143
    std::lock_guard<std::mutex> lock(mutex_);
144
145
146
147
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
148
    std::lock_guard<std::mutex> lock(mutex_);
149
150
151
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
152
153
154
155
156
157
158
159
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
160
161
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
162
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
163
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
164
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
165
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
166
167
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    }
Guolin Ke's avatar
Guolin Ke committed
169
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
170
171
  }

172
  void GetPredictAt(int data_idx, score_t* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
173
174
175
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
176
177
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
178
179
  }

180
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
Guolin Ke's avatar
Guolin Ke committed
181
    std::lock_guard<std::mutex> lock(mutex_);
182
183
184
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
185
186
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
187
  }
188

wxchan's avatar
wxchan committed
189
190
191
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
192

Guolin Ke's avatar
Guolin Ke committed
193
194
195
196
197
198
199
200
201
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
202
203
204
205
206
207
208
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
209

wxchan's avatar
wxchan committed
210
211
212
213
214
215
216
217
218
219
220
221
222
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
223
private:
224

wxchan's avatar
wxchan committed
225
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
226
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
227
228
229
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
230
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
231
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
232
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
233
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
234
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
235
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
236
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
237
238
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
239
240
241
};

}
Guolin Ke's avatar
Guolin Ke committed
242
243
244

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
245
/*! \brief Return the error message recorded by the last failing API call */
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

wxchan's avatar
wxchan committed
249
/*!
* \brief Load a dataset from a file.
*        When reference is non-null, the new dataset is aligned with it
*        through DatasetLoader::LoadFromFileAlignWithOtherDataset.
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr, filename);
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  *out = (ref_dataset == nullptr)
    ? loader.LoadFromFile(filename)
    : loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  API_END();
}

wxchan's avatar
wxchan committed
267
/*!
* \brief Build a dataset from a dense matrix.
*        Without a reference, bin mappers are built from a random row sample
*        (at most bin_construct_sample_cnt rows). With a reference, its feature
*        mappers are copied. All rows are then pushed in parallel.
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto row_getter = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // collect non-zero values of a random row sample, per column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto row_idx : sample_indices) {
      auto row = row_getter(static_cast<int>(row_idx));
      for (size_t col = 0; col < row.size(); ++col) {
        const auto value = row[col];
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = row_getter(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
316
/*!
* \brief Build a dataset from a CSR matrix.
*        Without a reference, bin mappers are built from a random row sample.
*        The per-feature sample buffers are padded to num_col so that columns
*        beyond the largest index seen in the sample are kept (as the dense path
*        does with sample_values(ncol)). With a reference, its feature mappers
*        are copied. All rows are then pushed in parallel.
* \param num_col Declared number of columns; must be >= the largest sampled feature index + 1
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (const auto& inner_data : row) {
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          // grow the feature set up to the largest index seen so far
          sample_values.resize(static_cast<size_t>(inner_data.first) + 1);
        }
        if (std::fabs(inner_data.second) > 1e-15) {
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    // Trailing columns that are all-zero in the sample must still become features;
    // otherwise the dataset ends up with fewer than num_col features.
    sample_values.resize(static_cast<size_t>(num_col));
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
378
/*!
* \brief Build a dataset from a CSC matrix.
*        Without a reference, each column is sampled (in parallel) at a random
*        row subset to build bin mappers. With a reference, its feature mappers
*        are copied. Columns are then pushed in parallel.
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  std::unique_ptr<Dataset> dataset;
  auto col_getter = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);

  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto cur_col = col_getter(col);
      sample_values[col] = SampleFromOneColumn(cur_col, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = col_getter(i);
    dataset->PushOneColumn(tid, i, one_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

wxchan's avatar
wxchan committed
428
/*! \brief Create a new dataset holding only the listed rows of handle */
DllExport int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
448
/*! \brief Copy num_feature_names C strings into the dataset's feature names */
DllExport int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  for (int64_t pos = 0; pos < num_feature_names; ++pos) {
    names.push_back(std::string(feature_names[pos]));
  }
  dataset->set_feature_names(names);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
462
/*! \brief Delete a dataset created by one of the LGBM_DatasetCreate* calls */
DllExport int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
468
/*! \brief Write the dataset to filename via Dataset::SaveBinaryFile */
DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
476
/*!
* \brief Set a metadata field (e.g. label, weight) of the dataset.
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other type fails
* \note Throws std::runtime_error (reported through API_END) when the type is
*       unsupported or the dataset rejects the field name
*/
DllExport int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
493
/*!
* \brief Expose a metadata field of the dataset without copying.
*        The float accessor is tried first, then the int accessor; *out_type
*        records which one matched. A null data pointer reports length 0.
*/
DllExport int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
513
/*! \brief Report the number of rows in the dataset */
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

Guolin Ke's avatar
typo  
Guolin Ke committed
521
/*! \brief Report the total number of features (Dataset::num_total_features) */
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
528
529
530
531


// ---- start of booster

Guolin Ke's avatar
typo  
Guolin Ke committed
532
/*! \brief Create a training booster on train_data configured by parameters */
DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* train_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(train_set, parameters));
  *out = booster.release();
  API_END();
}

wxchan's avatar
wxchan committed
542
/*!
* \brief Load a booster from a model file.
* \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses()
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* boosting = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(boosting->NumberOfTotalModel()
    / boosting->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*! \brief Delete a booster created by LGBM_BoosterCreate* */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

wxchan's avatar
wxchan committed
560
561
562
563
564
565
566
567
568
569
/*! \brief Merge the models of other_handle into handle */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

/*! \brief Register a validation dataset with the booster */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Replace the booster's training dataset */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

/*! \brief Apply a new parameter string to the booster (see Booster::ResetConfig) */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

/*! \brief Report the number of classes of the underlying boosting model */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

601
/*!
* \brief Run one boosting iteration with the configured objective.
* \param is_finished Set to 1 when TrainOneIter returns true, otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one boosting iteration with caller-provided gradients/hessians.
* \param is_finished Set to 1 when TrainOneIter returns true, otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
/*! \brief Undo the most recent boosting iteration */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Report the current iteration of the underlying boosting model */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get the number of evaluation results reported per dataset
*        (the total count of names over all training metrics)
* \param handle Booster handle
* \param out_len Receives the count
* \return status code produced by the API_BEGIN/API_END wrappers
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

/*!
* \brief Copy the names of all training-metric evaluation results into out_strs
* \param handle Booster handle
* \param out_len Receives the number of names written
* \param out_strs Caller-allocated array of LGBM_BoosterGetEvalCounts() char buffers,
*        each large enough for its name (names are copied with strcpy, no bounds check)
* \return status code produced by the API_BEGIN/API_END wrappers
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results of dataset data_idx into out_results as floats.
* \param out_len Receives the number of results written
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(eval_values.size());
  size_t pos = 0;
  for (const auto& value : eval_values) {
    out_results[pos++] = static_cast<float>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
677
/*! \brief Copy the current scores of dataset data_idx into out_result */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

687
688
/*!
* \brief Predict every row of data_filename and write the results to result_filename.
* \param data_has_header Any value > 0 means the file has a header line
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

701
/*!
* \brief Predict every row of a CSR matrix into out_result (row-major).
*        Each row occupies num_preb_in_one_row slots: NumberOfClasses() for
*        normal/raw scores, times the iteration count for leaf-index output.
*        Offsets and *out_len are computed in 64-bit so large inputs do not
*        overflow int.
* \param out_len Receives nrow * num_preb_in_one_row (0 when nindptr <= 0)
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // number of output values produced for each input row
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= num_iteration;
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  // nindptr == 0 would otherwise give nrow == -1 and a negative *out_len
  const int nrow = nindptr > 0 ? static_cast<int>(nindptr - 1) : 0;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit row offset: i * num_preb_in_one_row can exceed INT_MAX
    const int64_t offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      out_result[offset + static_cast<int64_t>(j)] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
739
740
741

// Predicts every row of a dense nrow x ncol matrix with the booster behind `handle`.
// `data` is read as float32/float64 per `data_type`, in row- or column-major order
// per `is_row_major`.
// Results are written row by row into `out_result`: row i occupies the
// `num_preb_in_one_row` slots starting at i * num_preb_in_one_row.
// `*out_len` receives nrow * num_preb_in_one_row.
// For C_API_PREDICT_LEAF_INDEX there is one value per tree: num_classes * num_iteration
// when num_iteration > 0, otherwise the booster's total model count.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  // Offsets into out_result are 64-bit: nrow * num_preb_in_one_row can exceed
  // INT_MAX for large inputs, and signed int overflow is undefined behavior.
  const int64_t row_stride = static_cast<int64_t>(num_preb_in_one_row);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * row_stride;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * row_stride;
  API_END();
}
774
775

// Saves the model held by the booster behind `handle` to `filename`.
// `num_iteration` is passed unchanged to Booster::SaveModelToFile.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

// Dumps the model (via Booster::DumpModel) into a caller-provided buffer.
// `*out_len` is always set to the size needed: the string length plus the
// terminating NUL. The string is copied into `*out_str` only when that size
// fits within `buffer_len`; otherwise nothing is written to the buffer, and the
// caller can use `*out_len` to size a larger one.
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel();
  const int64_t required_len = static_cast<int64_t>(dumped.size()) + 1;
  *out_len = required_len;
  const bool fits_in_buffer = (required_len <= buffer_len);
  if (fits_in_buffer) {
    std::strcpy(*out_str, dumped.c_str());
  }
  API_END();
}
797

Guolin Ke's avatar
Guolin Ke committed
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820


// Reads the value of leaf `leaf_idx` in tree `tree_idx` and stores it,
// narrowed to float, in `*out_val`.
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  float* out_val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<float>(leaf_value);
  API_END();
}


// Overwrites the value of leaf `leaf_idx` in tree `tree_idx` with `val`,
// widened to double before it is passed to Booster::SetLeafValue.
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  float val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const double widened_val = static_cast<double>(val);
  booster->SetLeafValue(tree_idx, leaf_idx, widened_val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
821
// ---- start of helper functions
822
823
824

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
825
  if (data_type == C_API_DTYPE_FLOAT32) {
826
827
828
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
829
        std::vector<double> ret(num_col);
830
831
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
832
          ret[i] = static_cast<double>(*(tmp_ptr + i));
833
834
835
836
837
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
838
        std::vector<double> ret(num_col);
839
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
840
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
841
842
843
844
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
845
  } else if (data_type == C_API_DTYPE_FLOAT64) {
846
847
848
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
849
        std::vector<double> ret(num_col);
850
851
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
852
          ret[i] = static_cast<double>(*(tmp_ptr + i));
853
854
855
856
857
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
858
        std::vector<double> ret(num_col);
859
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
860
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
861
862
863
864
865
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
866
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
867
868
869
870
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
871
872
873
874
875
876
877
878
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
879
        }
Guolin Ke's avatar
Guolin Ke committed
880
881
882
      }
      return ret;
    };
883
  }
Guolin Ke's avatar
Guolin Ke committed
884
  return nullptr;
885
886
887
888
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
889
  if (data_type == C_API_DTYPE_FLOAT32) {
890
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
891
    if (indptr_type == C_API_DTYPE_INT32) {
892
893
894
895
896
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
897
        for (int64_t i = start; i < end; ++i) {
898
899
900
901
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
902
    } else if (indptr_type == C_API_DTYPE_INT64) {
903
904
905
906
907
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
908
        for (int64_t i = start; i < end; ++i) {
909
910
911
912
913
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
914
  } else if (data_type == C_API_DTYPE_FLOAT64) {
915
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
916
    if (indptr_type == C_API_DTYPE_INT32) {
917
918
919
920
921
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
922
        for (int64_t i = start; i < end; ++i) {
923
924
925
926
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
927
    } else if (indptr_type == C_API_DTYPE_INT64) {
928
929
930
931
932
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
933
        for (int64_t i = start; i < end; ++i) {
934
935
936
937
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
938
939
940
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
941
942
943
944
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
945
  if (data_type == C_API_DTYPE_FLOAT32) {
946
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
947
    if (col_ptr_type == C_API_DTYPE_INT32) {
948
949
950
951
952
953
954
955
956
957
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
958
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
959
960
961
962
963
964
965
966
967
968
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
969
    } 
Guolin Ke's avatar
Guolin Ke committed
970
  } else if (data_type == C_API_DTYPE_FLOAT64) {
971
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
972
    if (col_ptr_type == C_API_DTYPE_INT32) {
973
974
975
976
977
978
979
980
981
982
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
983
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
984
985
986
987
988
989
990
991
992
993
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
994
995
996
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
997
998
}

999
// Picks values out of one sparse column.
// `data` holds (row index, value) pairs. For each row in `indices`, taken in
// order, the value of the matching pair is appended to the result; rows with
// no entry in `data` contribute nothing.
// A single forward cursor walks `data` and never moves backwards, so this
// assumes both `data` (by row) and `indices` are sorted ascending; confirm
// with callers.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto last = data.end();
  for (const int wanted_row : indices) {
    while (cursor != last && cursor->first < wanted_row) {
      ++cursor;
    }
    if (cursor != last && cursor->first == wanted_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}