Unverified Commit bfb346c1 authored by Drew Miller's avatar Drew Miller Committed by GitHub
Browse files

[c_api] Improve ANSI compatibility by avoiding <stdbool.h> (#4697)

* [c_api] Improve ANSI compatibility by avoiding <stdbool.h>

* fixes in response to CI linting

* inline NOLINT instead of separate test

* moving length declaration to non-ANSI C conditional

* [c_api] Align expected return type in `basic.py` with new c_api type.
parent 874e6359
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <stdbool.h>
#endif #endif
...@@ -434,12 +433,12 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetAddFeaturesFrom(DatasetHandle target, ...@@ -434,12 +433,12 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
/* --- start Booster interfaces */ /* --- start Booster interfaces */
/*! /*!
* \brief Get boolean representing whether booster is fitting linear trees. * \brief Get int representing whether booster is fitting linear trees.
* \param handle Handle of booster * \param handle Handle of booster
* \param[out] out The address to hold linear trees indicator * \param[out] out The address to hold linear trees indicator
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out); LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, int* out);
/*! /*!
* \brief Create a new boosting learner. * \brief Create a new boosting learner.
...@@ -1361,11 +1360,17 @@ static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everythin ...@@ -1361,11 +1360,17 @@ static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everythin
#endif #endif
/*! /*!
* \brief Set string message of the last error. * \brief Set string message of the last error.
* \note
* This will call unsafe ``sprintf`` when compiled using C standards before C99.
* \param msg Error message * \param msg Error message
*/ */
INLINE_FUNCTION void LGBM_SetLastError(const char* msg) { INLINE_FUNCTION void LGBM_SetLastError(const char* msg) {
#if !defined(__cplusplus) && (!defined(__STDC__) || (__STDC_VERSION__ < 199901L))
sprintf(LastErrorMsg(), "%s", msg); /* NOLINT(runtime/printf) */
#else
const int err_buf_len = 512; const int err_buf_len = 512;
snprintf(LastErrorMsg(), err_buf_len, "%s", msg); snprintf(LastErrorMsg(), err_buf_len, "%s", msg);
#endif
} }
#endif /* LIGHTGBM_C_API_H_ */ #endif /* LIGHTGBM_C_API_H_ */
...@@ -3598,7 +3598,7 @@ class Booster: ...@@ -3598,7 +3598,7 @@ class Booster:
predictor = self._to_predictor(deepcopy(kwargs)) predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True) leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_bool(False) out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear( _safe_call(_LIB.LGBM_BoosterGetLinear(
self.handle, self.handle,
ctypes.byref(out_is_linear))) ctypes.byref(out_is_linear)))
...@@ -3607,7 +3607,7 @@ class Booster: ...@@ -3607,7 +3607,7 @@ class Booster:
params=self.params, params=self.params,
default_value=None default_value=None
) )
new_params["linear_tree"] = out_is_linear.value new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, silent=True, params=new_params) train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set) new_booster = Booster(new_params, train_set)
......
...@@ -1639,10 +1639,14 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) { ...@@ -1639,10 +1639,14 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END(); API_END();
} }
int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out) { int LGBM_BoosterGetLinear(BoosterHandle handle, int* out) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out = ref_booster->GetBoosting()->IsLinear(); if (ref_booster->GetBoosting()->IsLinear()) {
*out = 1;
} else {
*out = 0;
}
API_END(); API_END();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment