Commit 516e66a5 authored by Guolin Ke's avatar Guolin Ke
Browse files

add array allocate api to solve the max_array_size problem in C#.

parent 974230fa
......@@ -16,6 +16,7 @@
#include <LightGBM/export.h>
typedef void* ArrayHandle;
typedef void* DatasetHandle;
typedef void* BoosterHandle;
......@@ -36,7 +37,6 @@ typedef void* BoosterHandle;
*/
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
// --- start Dataset interface
/*!
......@@ -769,4 +769,10 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;
LIGHTGBM_C_EXPORT int LGBM_AllocateArray(int64_t len, int type, ArrayHandle* out);
LIGHTGBM_C_EXPORT int LGBM_CopyToArray(ArrayHandle arr, int type, int64_t start_idx, const void* src, int64_t len);
LIGHTGBM_C_EXPORT int LGBM_FreeArray(ArrayHandle arr, int type);
#endif // LIGHTGBM_C_API_H_
......@@ -1123,6 +1123,56 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_AllocateArray(int64_t len, int type, ArrayHandle* out) {
API_BEGIN();
if (type == C_API_DTYPE_FLOAT32) {
*out = new float[len];
} else if (type == C_API_DTYPE_FLOAT64) {
*out = new double[len];
} else if (type == C_API_DTYPE_INT32) {
*out = new int32_t[len];
} else if (type == C_API_DTYPE_INT64) {
*out = new int64_t[len];
}
API_END();
}
template<typename T>
void Copy(T* dst, const T* src, int64_t len) {
for (int64_t i = 0; i < len; ++i) {
dst[i] = src[i];
}
}
LIGHTGBM_C_EXPORT int LGBM_CopyToArray(ArrayHandle arr, int type, int64_t start_idx, const void* src, int64_t len) {
API_BEGIN();
if (type == C_API_DTYPE_FLOAT32) {
Copy<float>(static_cast<float*>(arr) + start_idx, static_cast<const float*>(src), len);
} else if (type == C_API_DTYPE_FLOAT64) {
Copy<double>(static_cast<double*>(arr) + start_idx, static_cast<const double*>(src), len);
} else if (type == C_API_DTYPE_INT32) {
Copy<int32_t>(static_cast<int32_t*>(arr) + start_idx, static_cast<const int32_t*>(src), len);
} else if (type == C_API_DTYPE_INT64) {
Copy<int64_t>(static_cast<int64_t*>(arr) + start_idx, static_cast<const int64_t*>(src), len);
}
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_FreeArray(ArrayHandle arr, int type) {
API_BEGIN();
if (type == C_API_DTYPE_FLOAT32) {
delete[] static_cast<float*>(arr);
} else if (type == C_API_DTYPE_FLOAT64) {
delete[] static_cast<double*>(arr);
} else if (type == C_API_DTYPE_INT32) {
delete[] static_cast<int32_t*>(arr);
} else if (type == C_API_DTYPE_INT64) {
delete[] static_cast<int64_t*>(arr);
}
API_END();
}
// ---- start of some help functions
std::function<std::vector<double>(int row_idx)>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment