Unverified Commit bfb346c1 authored by Drew Miller's avatar Drew Miller Committed by GitHub
Browse files

[c_api] Improve ANSI compatibility by avoiding <stdbool.h> (#4697)

* [c_api] Improve ANSI compatibility by avoiding <stdbool.h>

* fixes in response to CI linting

* inline NOLINT instead of separate test

* moving length declaration to non-ANSI C conditional

* [c_api] Align expected return type in `basic.py` with new c_api type.
parent 874e6359
......@@ -23,7 +23,6 @@
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
#endif
......@@ -434,12 +433,12 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
/* --- start Booster interfaces */
/*!
* \brief Get boolean representing whether booster is fitting linear trees.
* \brief Get int representing whether booster is fitting linear trees.
* \param handle Handle of booster
* \param[out] out The address to hold linear trees indicator
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, int* out);
/*!
* \brief Create a new boosting learner.
......@@ -1361,11 +1360,17 @@ static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everythin
#endif
/*!
* \brief Set string message of the last error.
* \note
* This will call unsafe ``sprintf`` when compiled using C standards before C99.
* \param msg Error message
*/
INLINE_FUNCTION void LGBM_SetLastError(const char* msg) {
#if !defined(__cplusplus) && (!defined(__STDC__) || (__STDC_VERSION__ < 199901L))
sprintf(LastErrorMsg(), "%s", msg); /* NOLINT(runtime/printf) */
#else
const int err_buf_len = 512;
snprintf(LastErrorMsg(), err_buf_len, "%s", msg);
#endif
}
#endif /* LIGHTGBM_C_API_H_ */
......@@ -3598,7 +3598,7 @@ class Booster:
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_bool(False)
out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear(
self.handle,
ctypes.byref(out_is_linear)))
......@@ -3607,7 +3607,7 @@ class Booster:
params=self.params,
default_value=None
)
new_params["linear_tree"] = out_is_linear.value
new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
......
......@@ -1639,10 +1639,14 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END();
}
int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out) {
int LGBM_BoosterGetLinear(BoosterHandle handle, int* out) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out = ref_booster->GetBoosting()->IsLinear();
if (ref_booster->GetBoosting()->IsLinear()) {
*out = 1;
} else {
*out = 0;
}
API_END();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment