Unverified commit 5e24b80b, authored by Alberto Ferreira and committed by GitHub
Browse files

[refactor] Reduce code duplication in c_api.cpp (#3539)

* Refactor c_api.cpp with template code

* Further cleanup

* Fix whitespace for linter
parent 716451b1
...@@ -2309,48 +2309,37 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank, ...@@ -2309,48 +2309,37 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
// ---- start of some help functions // ---- start of some help functions
// Builds a callable that extracts one row of a dense matrix as doubles.
//
// `data` is reinterpreted as an array of T. When `is_row_major` is non-zero,
// element (r, c) is read from offset r * num_col + c; otherwise it is read
// from offset c * num_row + r. The returned callable copies the pointer and
// the dimensions by value; it does not copy the matrix itself, so the buffer
// behind `data` has to remain valid whenever the callable is invoked.
template<typename T>
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric_helper(const void* data, int num_row, int num_col, int is_row_major) {
  const T* values = reinterpret_cast<const T*>(data);
  if (!is_row_major) {
    // Column-major layout: each column occupies num_row consecutive elements.
    return [=] (int row_idx) {
      std::vector<double> row(num_col);
      for (int col = 0; col < num_col; ++col) {
        const T* col_begin = values + static_cast<size_t>(num_row) * col;
        row[col] = static_cast<double>(col_begin[row_idx]);
      }
      return row;
    };
  }
  // Row-major layout: the requested row is one contiguous run of num_col elements.
  return [=] (int row_idx) {
    std::vector<double> row(num_col);
    const T* row_begin = values + static_cast<size_t>(num_col) * row_idx;
    for (int col = 0; col < num_col; ++col) {
      row[col] = static_cast<double>(row_begin[col]);
    }
    return row;
  };
}
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); return RowFunctionFromDenseMatric_helper<float>(data, num_row, num_col, is_row_major);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); return RowFunctionFromDenseMatric_helper<double>(data, num_row, num_col, is_row_major);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
} }
Log::Fatal("Unknown data type in RowFunctionFromDenseMatric"); Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
return nullptr; return nullptr;
...@@ -2392,136 +2381,78 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) { ...@@ -2392,136 +2381,78 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
}; };
} }
// Builds a callable that returns the stored entries of one CSR row as
// (column index, value) pairs.
//
// T  - type of the row index accepted by the returned callable.
// T1 - element type that `data` is reinterpreted as.
// T2 - element type that `indptr` is reinterpreted as.
//
// For row `idx`, the entries at positions [indptr[idx], indptr[idx + 1]) of
// `indices` / `data` are emitted in order. Only the pointers are captured, so
// the three arrays have to remain valid whenever the callable is invoked.
template<typename T, typename T1, typename T2>
std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR_helper(const void* indptr, const int32_t* indices, const void* data) {
  const T2* row_offsets = reinterpret_cast<const T2*>(indptr);
  const T1* values = reinterpret_cast<const T1*>(data);
  return [=] (T idx) {
    const int64_t first = row_offsets[idx];
    const int64_t last = row_offsets[idx + 1];
    const int64_t entry_count = last - first;
    std::vector<std::pair<int, double>> row;
    // Reserve only for a positive count; an empty or inverted range yields an empty row.
    if (entry_count > 0) {
      row.reserve(entry_count);
    }
    for (int64_t pos = first; pos < last; ++pos) {
      row.emplace_back(indices[pos], values[pos]);
    }
    return row;
  };
}
template<typename T> template<typename T>
std::function<std::vector<std::pair<int, double>>(T idx)> std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) { RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); return RowFunctionFromCSR_helper<T, float, int32_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); return RowFunctionFromCSR_helper<T, float, int64_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} }
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); return RowFunctionFromCSR_helper<T, double, int32_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); return RowFunctionFromCSR_helper<T, double, int64_t>(indptr, indices, data);
return [=] (T idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
if (end - start > 0) {
ret.reserve(end - start);
}
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} }
} }
Log::Fatal("Unknown data type in RowFunctionFromCSR"); Log::Fatal("Unknown data type in RowFunctionFromCSR");
return nullptr; return nullptr;
} }
// Builds a callable that walks the stored entries of one CSC column.
//
// T1 - element type that `data` is reinterpreted as.
// T2 - element type that `col_ptr` is reinterpreted as.
//
// The column's range [col_ptr[col_idx], col_ptr[col_idx + 1]) is read once,
// when the callable is created. Calling it with `offset` returns the
// (row index, value) pair stored at range start + offset, or (-1, 0.0) once
// that position reaches the end of the column. `indices` and `data` are
// captured as pointers, so they have to remain valid whenever the callable is
// invoked.
template <typename T1, typename T2>
std::function<std::pair<int, double>(int idx)> IterateFunctionFromCSC_helper(const void* col_ptr, const int32_t* indices, const void* data, int col_idx) {
  const T2* col_offsets = reinterpret_cast<const T2*>(col_ptr);
  const T1* values = reinterpret_cast<const T1*>(data);
  const int64_t col_begin = col_offsets[col_idx];
  const int64_t col_end = col_offsets[col_idx + 1];
  return [=] (int offset) -> std::pair<int, double> {
    const int64_t pos = col_begin + offset;
    if (pos < col_end) {
      return std::make_pair(static_cast<int>(indices[pos]), static_cast<double>(values[pos]));
    }
    // Past the last stored entry of this column: signal exhaustion.
    return std::make_pair(-1, 0.0);
  };
}
std::function<std::pair<int, double>(int idx)> std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) { IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0); CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) { if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); return IterateFunctionFromCSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); return IterateFunctionFromCSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} }
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) { if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); return IterateFunctionFromCSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); return IterateFunctionFromCSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
} }
} }
Log::Fatal("Unknown data type in CSC matrix"); Log::Fatal("Unknown data type in CSC matrix");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment