"vscode:/vscode.git/clone" did not exist on "19de2be0deb388a0cafc0aef90de0d1c9da49812"
c_api.cpp 30.4 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
33
    std::unique_lock<std::mutex> lock(mutex_);
34
35
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
36
37
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
38
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
39
40
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
41
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
42
43
44
45
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
46
    lock.unlock();
47
48
  }

Guolin Ke's avatar
Guolin Ke committed
49
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
50
    std::unique_lock<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
51
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
52
    lock.unlock();
Guolin Ke's avatar
Guolin Ke committed
53
54
  }

55
56
57
58
59
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
60
    std::unique_lock<std::mutex> lock(mutex_);
61
62
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
63
    // initialize the boosting
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
66
    lock.unlock();
67
68
69
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
70
    std::unique_lock<std::mutex> lock(mutex_);
71
72
73
74
75
76
77
78
79
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
Guolin Ke's avatar
Guolin Ke committed
80
    lock.unlock();
81
  }
Guolin Ke's avatar
Guolin Ke committed
82

83
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
84
    std::unique_lock<std::mutex> lock(mutex_);
85
86
87
88
89
90
91
92
93
94
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
95
    lock.unlock();
Guolin Ke's avatar
Guolin Ke committed
96
  }
97
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
98
99
100
101
    std::unique_lock<std::mutex> lock(mutex_);
    bool ret = boosting_->TrainOneIter(nullptr, nullptr, false);
    lock.unlock();
    return ret;
102
103
104
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
    std::unique_lock<std::mutex> lock(mutex_);
    bool ret = boosting_->TrainOneIter(gradients, hessians, false);
    lock.unlock();
    return ret;
109
110
  }

Guolin Ke's avatar
Guolin Ke committed
111
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
112
    std::unique_lock<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
113
    boosting_->RollbackOneIter();
Guolin Ke's avatar
Guolin Ke committed
114
    lock.unlock();
Guolin Ke's avatar
Guolin Ke committed
115
116
  }

Guolin Ke's avatar
Guolin Ke committed
117
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
118
    std::unique_lock<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
119
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
120
121
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
122
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
123
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
124
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
125
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
126
127
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
130
    lock.unlock();
Guolin Ke's avatar
Guolin Ke committed
131
132
  }

Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
137
138
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
139
140
  }

141
142
143
144
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
145
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
    std::unique_lock<std::mutex> lock(mutex_);
    boosting_->SaveModelToFile(num_iteration, filename);
    lock.unlock();
Guolin Ke's avatar
Guolin Ke committed
149
  }
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
156
157
158

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
159
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
160
161
162
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
163
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
164
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
170
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
171
  
Guolin Ke's avatar
Guolin Ke committed
172
private:
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

197
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
198
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
199
200
201
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
202
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
203
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
204
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
205
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
206
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
207
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
208
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
209
210
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
211
212
213
};

}
Guolin Ke's avatar
Guolin Ke committed
214
215
216

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
217
DllExport const char* LGBM_GetLastError() {
  // Expose the message text tracked by LastErrorMsg() as a C string.
  const char* last_error = LastErrorMsg().c_str();
  return last_error;
}

/*!
* \brief Load a dataset from a text file. When reference is given, its bin mappers are reused.
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader file_loader(io_config, nullptr);
  file_loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* reference_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = file_loader.LoadFromFileAlignWithOtherDataset(filename, reference_dataset);
  } else {
    *out = file_loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
* \brief Load a dataset previously saved in binary form, using default IO settings.
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig default_config;
  DatasetLoader binary_loader(default_config.io_config, nullptr);
  // The 0 and 1 arguments are forwarded unchanged (presumably rank 0 of 1 machine — TODO confirm).
  *out = binary_loader.LoadFromBinFile(filename, 0, 1);
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix.
*
* Without a reference, bin mappers are built from a random sample of rows
* (values with |v| <= 1e-15 are skipped). With a reference, its mappers are copied.
* All rows are then pushed in parallel.
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  std::unique_ptr<Dataset> dataset;
  if (reference != nullptr) {
    // Reuse the reference dataset's feature mappers.
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // Collect the non-zero values of sampled rows, one vector per column.
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto& sampled_row : sample_indices) {
      const auto values = row_at(static_cast<int>(sampled_row));
      for (size_t col = 0; col < values.size(); ++col) {
        const double value = values[col];
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int thread_id = omp_get_thread_num();
    auto row_values = row_at(row_idx);
    dataset->PushOneRow(thread_id, row_idx, row_values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

298
299
/*!
* \brief Build a dataset from a CSR sparse matrix.
* \param indptr Row pointer array (nindptr entries, type given by indptr_type)
* \param indices Column index of each stored element
* \param data Stored element values (type given by data_type)
* \param nindptr Number of entries in indptr; the row count is nindptr - 1
* \param nelem Number of stored elements
* \param num_col Total number of columns. This fixes the dataset's feature count
*        even if some columns have no non-zero values in the sampled rows.
* \param reference If non-null, bin mappers are copied from this dataset instead of sampled
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
            // grow the feature set to cover this column index
            sample_values.resize(static_cast<size_t>(inner_data.first) + 1);
          }
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    // Columns beyond the largest index seen in the sample still exist. Without
    // this resize they would be dropped from the feature set, and rows outside
    // the sample that use them would reference features the dataset lacks.
    // This matches the dense-matrix path, which sizes sample_values by ncol.
    sample_values.resize(static_cast<size_t>(num_col));
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

360
361
/*!
* \brief Build a dataset from a CSC sparse matrix.
*
* Without a reference, bin mappers are built by sampling each column at the
* same randomly chosen rows. With a reference, its mappers are copied.
* Columns are then pushed in parallel.
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  auto column_at = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);
  const int num_columns = static_cast<int>(ncol_ptr - 1);
  std::unique_ptr<Dataset> dataset;
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(num_columns);
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_columns; ++col) {
      auto column = column_at(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col = 0; col < num_columns; ++col) {
    const int thread_id = omp_get_thread_num();
    auto column = column_at(col);
    dataset->PushOneColumn(thread_id, col, column);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

410
/*! \brief Destroy a dataset created by one of the LGBM_CreateDataset* functions. */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* owned_dataset = reinterpret_cast<Dataset*>(handle);
  delete owned_dataset;
  API_END();
}

/*! \brief Save the dataset to filename in binary form. */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named per-row field (e.g. label/weight) on a dataset.
* \param field_data Pointer to num_element values of the type given by type
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value is rejected
* \throws std::runtime_error (converted by API_END) if the type is unsupported or the
*         dataset rejects the field
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a named field, trying the float fields first and then the int fields.
*
* Writes the field's pointer, length and dtype. A null pointer yields length 0.
* Throws "Field not found" if neither lookup succeeds.
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool found = dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr));
  if (found) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else {
    found = dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr));
    if (found) {
      *out_type = C_API_DTYPE_INT32;
    }
  }
  if (!found) { throw std::runtime_error("Field not found"); }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*! \brief Report the number of rows in the dataset. */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const auto num_rows = reinterpret_cast<Dataset*>(handle)->num_data();
  *out = num_rows;
  API_END();
}

/*! \brief Report the total number of features in the dataset. */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const auto num_features = reinterpret_cast<Dataset*>(handle)->num_total_features();
  *out = num_features;
  API_END();
}
476
477
478
479
480
481
482


// ---- start of booster

/*! \brief Create a training booster for train_data configured by parameters. */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

490
/*! \brief Load a booster from a model file and report how many models it contains. */
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const auto model_count = booster->GetBoosting()->NumberOfTotalModel();
  *num_total_model = static_cast<int64_t>(model_count);
  *out = booster.release();
  API_END();
}

/*! \brief Destroy a booster created by LGBM_BoosterCreate*. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* owned_booster = reinterpret_cast<Booster*>(handle);
  delete owned_booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
511
512
513
514
/*! \brief Merge the model of other_handle into the booster behind handle. */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  auto* target = reinterpret_cast<Booster*>(handle);
  const auto* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533

/*! \brief Register valid_data as an additional validation set of the booster. */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Replace the booster's training dataset. */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

534
535
536
/*! \brief Apply a new parameter string to the booster. */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
541
542
543
544
545
546
547
/*! \brief Report the number of classes of the booster's model. */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

548
/*! \brief Train one iteration with the built-in objective; is_finished receives 1 or 0. */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*! \brief Train one iteration with caller-provided gradients/hessians; is_finished receives 1 or 0. */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

573
574
575
576
577
578
579
580
581
582
583
584
585
/*! \brief Undo the booster's most recent iteration. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Report the booster's current iteration index. */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
/*!
* \brief Get the number of evaluation results reported for the training data
* \param handle Booster handle
* \param out_len Receives the total number of metric names across all training metrics
* \return Status code produced by API_END
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const auto* booster = reinterpret_cast<Booster*>(handle);
  const int eval_count = booster->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
601
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
602
603
604
605
606
607
608
609
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the current evaluation results for dataset index data into out_results as floats.
* \param out_results Caller-allocated buffer large enough for all results
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const auto* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data);
  *out_len = static_cast<int64_t>(eval_values.size());
  float* dst = out_results;
  for (const auto& value : eval_values) {
    *dst++ = static_cast<float>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
624
625
/*! \brief Copy the booster's current scores for dataset index data into out_result. */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  int produced = 0;
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data, out_result, &produced);
  *out_len = static_cast<int64_t>(produced);
  API_END();
}

636
637
/*!
* \brief Predict every row of data_filename and write the results to result_filename.
* \param data_has_header Treated as true when > 0
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

650
/*!
* \brief Predict each row of a CSR matrix.
*
* Row i's results are written starting at out_result[i * num_preb_in_one_row].
* The caller must allocate (nindptr - 1) * num_preb_in_one_row floats, where
* num_preb_in_one_row is the class count, multiplied by the iteration count
* for leaf-index prediction.
* The unnamed int64_t parameter (presumably the column count, mirroring
* LGBM_CreateDatasetFromCSR) is unused.
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * outputs-per-row can exceed INT_MAX for leaf-index output.
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
688
689
690

/*!
* \brief Predict each row of a dense matrix.
*
* Row i's results are written starting at out_result[i * num_preb_in_one_row].
* The caller must allocate nrow * num_preb_in_one_row floats, where
* num_preb_in_one_row is the class count, multiplied by the iteration count
* for leaf-index prediction.
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * outputs-per-row can exceed INT_MAX for leaf-index output.
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
723
724

/*!
 * Writes the model owned by the booster behind `handle` to `filename`.
 * `num_iteration` is forwarded unchanged to Booster::SaveModelToFile.
 */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  // the C handle is an opaque pointer to a Booster object
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
732

Guolin Ke's avatar
Guolin Ke committed
733
// ---- start of helper functions
734
735
736

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
737
  if (data_type == C_API_DTYPE_FLOAT32) {
738
739
740
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
741
        std::vector<double> ret(num_col);
742
743
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
744
          ret[i] = static_cast<double>(*(tmp_ptr + i));
745
746
747
748
749
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
750
        std::vector<double> ret(num_col);
751
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
752
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
753
754
755
756
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
757
  } else if (data_type == C_API_DTYPE_FLOAT64) {
758
759
760
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
761
        std::vector<double> ret(num_col);
762
763
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
764
          ret[i] = static_cast<double>(*(tmp_ptr + i));
765
766
767
768
769
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
770
        std::vector<double> ret(num_col);
771
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
772
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
773
774
775
776
777
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
778
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
779
780
781
782
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
790
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
791
        }
Guolin Ke's avatar
Guolin Ke committed
792
793
794
      }
      return ret;
    };
795
  }
Guolin Ke's avatar
Guolin Ke committed
796
  return nullptr;
797
798
799
800
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
801
  if (data_type == C_API_DTYPE_FLOAT32) {
802
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
803
    if (indptr_type == C_API_DTYPE_INT32) {
804
805
806
807
808
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
809
        for (int64_t i = start; i < end; ++i) {
810
811
812
813
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
814
    } else if (indptr_type == C_API_DTYPE_INT64) {
815
816
817
818
819
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
820
        for (int64_t i = start; i < end; ++i) {
821
822
823
824
825
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
826
  } else if (data_type == C_API_DTYPE_FLOAT64) {
827
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
828
    if (indptr_type == C_API_DTYPE_INT32) {
829
830
831
832
833
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
834
        for (int64_t i = start; i < end; ++i) {
835
836
837
838
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
839
    } else if (indptr_type == C_API_DTYPE_INT64) {
840
841
842
843
844
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
845
        for (int64_t i = start; i < end; ++i) {
846
847
848
849
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
850
851
852
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
853
854
855
856
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
857
  if (data_type == C_API_DTYPE_FLOAT32) {
858
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
859
    if (col_ptr_type == C_API_DTYPE_INT32) {
860
861
862
863
864
865
866
867
868
869
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
870
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
871
872
873
874
875
876
877
878
879
880
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
881
    } 
Guolin Ke's avatar
Guolin Ke committed
882
  } else if (data_type == C_API_DTYPE_FLOAT64) {
883
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
884
    if (col_ptr_type == C_API_DTYPE_INT32) {
885
886
887
888
889
890
891
892
893
894
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
895
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
896
897
898
899
900
901
902
903
904
905
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
906
907
908
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
909
910
}

911
/*!
 * Picks the values of a sparse column at the requested row positions.
 *
 * `data` holds (row index, value) pairs and `indices` lists the wanted rows.
 * The scan is a single forward pass over both sequences, so it is only
 * complete when both are sorted in ascending row order. Requested rows that
 * have no stored entry contribute nothing to the result.
 */
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto stop = data.end();
  for (const int wanted : indices) {
    // advance past stored rows that precede the requested one
    while (cursor != stop && cursor->first < wanted) {
      ++cursor;
    }
    if (cursor != stop && cursor->first == wanted) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}