Commit 9d375069 authored by Guolin Ke's avatar Guolin Ke
Browse files

add LGBM_DatasetCreateFromSampledMat api.

parent 9f69165b
......@@ -48,13 +48,33 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling matrix, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param data pointer to the data space
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param num_sample_row number of rows
* \param ncol number columns
* \param num_total_row number of total rows
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling csc data, if num_sample_row == num_total_row, will construct this dataset.
* \brief create a empty dataset by sampling CSR data, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param indptr pointer to row headers
* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param indices findex
......@@ -69,16 +89,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t n_sample_elem,
int64_t num_col,
int64_t num_total_row,
const char* parameters,
DatasetHandle* out);
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t n_sample_elem,
int64_t num_col,
int64_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*!
* \brief create a empty dataset by reference Dataset
......@@ -88,8 +108,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row,
DatasetHandle* out);
int64_t num_total_row,
DatasetHandle* out);
/*!
* \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad
......@@ -102,11 +122,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int32_t start_row);
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int32_t start_row);
/*!
* \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad
......@@ -123,15 +143,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int64_t start_row);
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int64_t start_row);
/*!
* \brief create a dataset from CSR format
......@@ -149,16 +169,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
/*!
* \brief create a dataset from CSC format
......@@ -176,16 +196,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
/*!
* \brief create dataset from dense matrix
......@@ -200,13 +220,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out);
/*!
* \brief Create subset of a data
......@@ -263,7 +283,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle);
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
const char* filename);
const char* filename);
/*!
* \brief set vector to a content in info
......@@ -277,10 +297,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name,
const void* field_data,
int num_element,
int type);
const char* field_name,
const void* field_data,
int num_element,
int type);
/*!
* \brief get info vector from dataset
......@@ -292,10 +312,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name,
int* out_len,
const void** out_ptr,
int* out_type);
const char* field_name,
int* out_len,
const void** out_ptr,
int* out_type);
/*!
* \brief get number of data.
......@@ -304,7 +324,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out);
int* out);
/*!
* \brief get number of features
......@@ -313,7 +333,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out);
int* out);
// --- start Booster interfaces
......@@ -325,8 +345,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
const char* parameters,
BoosterHandle* out);
const char* parameters,
BoosterHandle* out);
/*!
* \brief load an existing boosting from model file
......@@ -366,7 +386,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle);
BoosterHandle other_handle);
/*!
* \brief Add new validation to booster
......@@ -375,7 +395,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatasetHandle valid_data);
const DatasetHandle valid_data);
/*!
* \brief Reset training data for booster
......@@ -384,7 +404,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatasetHandle train_data);
const DatasetHandle train_data);
/*!
* \brief Reset config for current booster
......@@ -420,9 +440,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad,
const float* hess,
int* is_finished);
const float* grad,
const float* hess,
int* is_finished);
/*!
* \brief Rollback one iteration
......@@ -479,9 +499,9 @@ Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evalu
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx,
int* out_len,
double* out_results);
int data_idx,
int* out_len,
double* out_results);
/*!
* \brief Get number of predict for inner dataset
......@@ -493,8 +513,8 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
int data_idx,
int64_t* out_len);
int data_idx,
int64_t* out_len);
/*!
* \brief Get prediction for training data and validation data
......@@ -507,9 +527,9 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
int data_idx,
int64_t* out_len,
double* out_result);
int data_idx,
int64_t* out_len,
double* out_result);
/*!
* \brief make prediction for file
......@@ -525,11 +545,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename,
int data_has_header,
int predict_type,
int num_iteration,
const char* result_filename);
const char* data_filename,
int data_has_header,
int predict_type,
int num_iteration,
const char* result_filename);
/*!
* \brief Get number of prediction
......@@ -544,10 +564,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row,
int predict_type,
int num_iteration,
int64_t* out_len);
int num_row,
int predict_type,
int num_iteration,
int64_t* out_len);
/*!
* \brief make prediction for an new data set
......@@ -573,18 +593,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
/*!
* \brief make prediction for an new data set
......@@ -610,18 +630,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
/*!
* \brief make prediction for an new data set
......@@ -644,15 +664,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
/*!
* \brief save model into file
......@@ -662,8 +682,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration,
const char* filename);
int num_iteration,
const char* filename);
/*!
* \brief save model to string
......@@ -675,10 +695,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str);
int num_iteration,
int buffer_len,
int* out_len,
char* out_str);
/*!
* \brief dump model to json
......@@ -690,10 +710,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str);
int num_iteration,
int buffer_len,
int* out_len,
char* out_str);
/*!
* \brief Get leaf value
......@@ -704,9 +724,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
double* out_val);
int tree_idx,
int leaf_idx,
double* out_val);
/*!
* \brief Set leaf value
......@@ -717,9 +737,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
double val);
int tree_idx,
int leaf_idx,
double val);
#if defined(_MSC_VER)
// exception handle and error msg
......
......@@ -36,7 +36,7 @@ public:
}
Booster(const Dataset* train_data,
const char* parameters) {
const char* parameters) {
auto param = ConfigBase::Str2Map(parameters);
config_.Set(param);
if (config_.num_threads > 0) {
......@@ -52,7 +52,7 @@ public:
// initialize the boosting
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data);
}
......@@ -71,7 +71,7 @@ public:
train_data_ = train_data;
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
......@@ -92,7 +92,7 @@ public:
train_metric_.shrink_to_fit();
// reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void ResetConfig(const char* parameters) {
......@@ -116,7 +116,7 @@ public:
if (param.count("objective")) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
......@@ -127,7 +127,7 @@ public:
}
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
......@@ -142,7 +142,7 @@ public:
}
valid_metrics_.back().shrink_to_fit();
boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
}
bool TrainOneIter() {
......@@ -266,13 +266,13 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const void* data, int data_type, int64_t nindptr, int64_t nelem);
const void* data, int data_type, int64_t nindptr, int64_t nelem);
// Row iterator of on column for CSC matrix
class CSC_RowIterator {
public:
CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
~CSC_RowIterator() {}
// return value at idx, only can access by ascent order
double Get(int idx);
......@@ -293,9 +293,9 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
......@@ -305,25 +305,59 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
*out = loader.LoadFromFile(filename);
} else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename,
reinterpret_cast<const Dataset*>(reference));
reinterpret_cast<const Dataset*>(reference));
}
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out) {
if (num_sample_row == num_total_row) {
return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
} else {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
auto get_row_fun = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);
std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol);
for (int i = 0; i < num_sample_row; ++i) {
auto row = get_row_fun(i);
for (size_t idx = 0; idx < row.size(); ++idx) {
if (std::fabs(row[idx]) > kEpsilon) {
sample_values[idx].emplace_back(row[idx]);
sample_idx[idx].emplace_back(i);
}
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t n_sample_elem,
int64_t num_col,
int64_t num_total_row,
const char* parameters,
DatasetHandle* out) {
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t n_sample_elem,
int64_t num_col,
int64_t num_total_row,
const char* parameters,
DatasetHandle* out) {
if (nindptr - 1 == num_total_row) {
return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data,
data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out);
data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out);
} else {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
......@@ -349,15 +383,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row,
DatasetHandle* out) {
int64_t num_total_row,
DatasetHandle* out) {
API_BEGIN();
std::unique_ptr<Dataset> ret;
ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
......@@ -367,11 +401,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
}
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int32_t start_row) {
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int32_t start_row) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
......@@ -388,15 +422,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int64_t start_row) {
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int64_t start_row) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
......@@ -406,7 +440,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const int tid = omp_get_thread_num();
auto one_row = get_row_fun(i);
p_dataset->PushOneRow(tid,
static_cast<data_size_t>(start_row + i), one_row);
static_cast<data_size_t>(start_row + i), one_row);
}
if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
p_dataset->FinishLoad();
......@@ -415,13 +449,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
......@@ -465,16 +499,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
......@@ -524,16 +558,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
......@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
for (int j = 0; j < sample_cnt; j++) {
......@@ -641,7 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
}
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
const char* filename) {
const char* filename) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename);
......@@ -649,10 +683,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name,
const void* field_data,
int num_element,
int type) {
const char* field_name,
const void* field_data,
int num_element,
int type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
......@@ -668,10 +702,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name,
int* out_len,
const void** out_ptr,
int* out_type) {
const char* field_name,
int* out_len,
const void** out_ptr,
int* out_type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
......@@ -691,7 +725,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out) {
int* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data();
......@@ -699,7 +733,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out) {
int* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
......@@ -709,8 +743,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
// ---- start of booster
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
const char* parameters,
BoosterHandle* out) {
const char* parameters,
BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
......@@ -748,7 +782,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
}
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) {
BoosterHandle other_handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle);
......@@ -757,7 +791,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatasetHandle valid_data) {
const DatasetHandle valid_data) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
......@@ -766,7 +800,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatasetHandle train_data) {
const DatasetHandle train_data) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
......@@ -800,9 +834,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi
}
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad,
const float* hess,
int* is_finished) {
const float* grad,
const float* hess,
int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) {
......@@ -856,9 +890,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_l
}
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx,
int* out_len,
double* out_results) {
int data_idx,
int* out_len,
double* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
......@@ -871,8 +905,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
int data_idx,
int64_t* out_len) {
int data_idx,
int64_t* out_len) {
API_BEGIN();
auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
*out_len = boosting->GetNumPredictAt(data_idx);
......@@ -880,9 +914,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
int data_idx,
int64_t* out_len,
double* out_result) {
int data_idx,
int64_t* out_len,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->GetPredictAt(data_idx, out_result, out_len);
......@@ -890,11 +924,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename,
int data_has_header,
int predict_type,
int num_iteration,
const char* result_filename) {
const char* data_filename,
int data_has_header,
int predict_type,
int num_iteration,
const char* result_filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
......@@ -917,10 +951,10 @@ int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t n
}
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row,
int predict_type,
int num_iteration,
int64_t* out_len) {
int num_row,
int predict_type,
int num_iteration,
int64_t* out_len) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = static_cast<int64_t>(num_row * GetNumPredOneRow(ref_booster, predict_type, num_iteration));
......@@ -928,18 +962,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
......@@ -959,18 +993,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
......@@ -978,7 +1012,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int ncol = static_cast<int>(ncol_ptr - 1);
Threading::For<int64_t>(0, num_row,
[&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
[&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
(int, data_size_t start, data_size_t end) {
std::vector<CSC_RowIterator> iterators;
for (int j = 0; j < ncol; ++j) {
......@@ -1004,15 +1038,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
......@@ -1031,8 +1065,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration,
const char* filename) {
int num_iteration,
const char* filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_iteration, filename);
......@@ -1040,10 +1074,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str) {
int num_iteration,
int buffer_len,
int* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(num_iteration);
......@@ -1055,10 +1089,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str) {
int num_iteration,
int buffer_len,
int* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(num_iteration);
......@@ -1070,9 +1104,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
double* out_val) {
int tree_idx,
int leaf_idx,
double* out_val) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_val = static_cast<double>(ref_booster->GetLeafValue(tree_idx, leaf_idx));
......@@ -1080,9 +1114,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
}
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
double val) {
int tree_idx,
int leaf_idx,
double val) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SetLeafValue(tree_idx, leaf_idx, val);
......@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
......@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
......@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
......@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
......@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) {
return [inner_function](int row_idx) {
return [inner_function] (int row_idx) {
auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
......@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
};
} else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
};
} else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1277,7 +1311,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
}
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
iter_fun_ = IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment