Unverified commit 5e24b80b, authored by Alberto Ferreira and committed by GitHub
Browse files

[refactor] Reduce code duplication in c_api.cpp (#3539)

* Refactor c_api.cpp with template code

* Further cleanup

* Fix whitespace for linter
parent 716451b1
...@@ -2309,10 +2309,11 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank, ...@@ -2309,10 +2309,11 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
// ---- start of some help functions // ---- start of some help functions
template<typename T>
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowFunctionFromDenseMatric_helper(const void* data, int num_row, int num_col, int is_row_major) {
if (data_type == C_API_DTYPE_FLOAT32) { const T* data_ptr = reinterpret_cast<const T*>(data);
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [=] (int row_idx) { return [=] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
...@@ -2331,26 +2332,14 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -2331,26 +2332,14 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} }
}
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == C_API_DTYPE_FLOAT32) {
return RowFunctionFromDenseMatric_helper<float>(data, num_row, num_col, is_row_major);
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); return RowFunctionFromDenseMatric_helper<double>(data, num_row, num_col, is_row_major);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
} }
Log::Fatal("Unknown data type in RowFunctionFromDenseMatric"); Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
return nullptr; return nullptr;
...@@ -2392,13 +2381,11 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) { ...@@ -2392,13 +2381,11 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
}; };
} }
template<typename T> template<typename T, typename T1, typename T2>
std::function<std::vector<std::pair<int, double>>(T idx)> std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) { RowFunctionFromCSR_helper(const void* indptr, const int32_t* indices, const void* data) {
if (data_type == C_API_DTYPE_FLOAT32) { const T1* data_ptr = reinterpret_cast<const T1*>(data);
const float* data_ptr = reinterpret_cast<const float*>(data); const T2* ptr_indptr = reinterpret_cast<const T2*>(indptr);
if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [=] (T idx) { return [=] (T idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
...@@ -2411,64 +2398,34 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -2411,64 +2398,34 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
} }
return ret; return ret;
}; };
}
template<typename T>
std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
if (data_type == C_API_DTYPE_FLOAT32) {
if (indptr_type == C_API_DTYPE_INT32) {
return RowFunctionFromCSR_helper<T, float, int32_t>(indptr, indices, data);
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); return RowFunctionFromCSR_helper<T, float, int64_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} }
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); return RowFunctionFromCSR_helper<T, double, int32_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); return RowFunctionFromCSR_helper<T, double, int64_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} }
} }
Log::Fatal("Unknown data type in RowFunctionFromCSR"); Log::Fatal("Unknown data type in RowFunctionFromCSR");
return nullptr; return nullptr;
} }
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0); template <typename T1, typename T2>
if (data_type == C_API_DTYPE_FLOAT32) { std::function<std::pair<int, double>(int idx)> IterateFunctionFromCSC_helper(const void* col_ptr, const int32_t* indices, const void* data, int col_idx) {
const float* data_ptr = reinterpret_cast<const float*>(data); const T1* data_ptr = reinterpret_cast<const T1*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) { const T2* ptr_col_ptr = reinterpret_cast<const T2*>(col_ptr);
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) { return [=] (int offset) {
...@@ -2480,48 +2437,22 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -2480,48 +2437,22 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
double val = static_cast<double>(data_ptr[i]); double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val); return std::make_pair(idx, val);
}; };
}
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (data_type == C_API_DTYPE_FLOAT32) {
if (col_ptr_type == C_API_DTYPE_INT32) {
return IterateFunctionFromCSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); return IterateFunctionFromCSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} }
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) { if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); return IterateFunctionFromCSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); return IterateFunctionFromCSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} }
} }
Log::Fatal("Unknown data type in CSC matrix"); Log::Fatal("Unknown data type in CSC matrix");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment