Commit 5b4b5d65 authored by Guolin Ke's avatar Guolin Ke
Browse files

32 bit compatibility fix

parent e29ab9f6
...@@ -163,7 +163,7 @@ DllExport int LGBM_DatasetGetSubset( ...@@ -163,7 +163,7 @@ DllExport int LGBM_DatasetGetSubset(
DllExport int LGBM_DatasetSetFeatureNames( DllExport int LGBM_DatasetSetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
const char** feature_names, const char** feature_names,
int64_t num_feature_names); int num_feature_names);
/*! /*!
...@@ -176,7 +176,7 @@ DllExport int LGBM_DatasetSetFeatureNames( ...@@ -176,7 +176,7 @@ DllExport int LGBM_DatasetSetFeatureNames(
DllExport int LGBM_DatasetGetFeatureNames( DllExport int LGBM_DatasetGetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
char** feature_names, char** feature_names,
int64_t* num_feature_names); int* num_feature_names);
/*! /*!
...@@ -208,7 +208,7 @@ DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle, ...@@ -208,7 +208,7 @@ DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
DllExport int LGBM_DatasetSetField(DatasetHandle handle, DllExport int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
int64_t num_element, int num_element,
int type); int type);
/*! /*!
...@@ -222,7 +222,7 @@ DllExport int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -222,7 +222,7 @@ DllExport int LGBM_DatasetSetField(DatasetHandle handle,
*/ */
DllExport int LGBM_DatasetGetField(DatasetHandle handle, DllExport int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int64_t* out_len, int* out_len,
const void** out_ptr, const void** out_ptr,
int* out_type); int* out_type);
...@@ -233,7 +233,7 @@ DllExport int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -233,7 +233,7 @@ DllExport int LGBM_DatasetGetField(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle, DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
int64_t* out); int* out);
/*! /*!
* \brief get number of features * \brief get number of features
...@@ -242,7 +242,7 @@ DllExport int LGBM_DatasetGetNumData(DatasetHandle handle, ...@@ -242,7 +242,7 @@ DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle, DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int64_t* out); int* out);
// --- start Booster interfaces // --- start Booster interfaces
...@@ -266,7 +266,7 @@ DllExport int LGBM_BoosterCreate(const DatasetHandle train_data, ...@@ -266,7 +266,7 @@ DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
*/ */
DllExport int LGBM_BoosterCreateFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
int64_t* out_num_iterations, int* out_num_iterations,
BoosterHandle* out); BoosterHandle* out);
...@@ -313,12 +313,12 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle, ...@@ -313,12 +313,12 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters); DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters);
/*! /*!
* \brief Get number of class * \brief Get number of class
* \param handle handle * \param handle handle
* \param out_len number of class * \param out_len number of class
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len); DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len);
/*! /*!
* \brief update the model in one round * \brief update the model in one round
...@@ -354,14 +354,14 @@ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle); ...@@ -354,14 +354,14 @@ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle);
* \param out_iteration iteration of boosting rounds * \param out_iteration iteration of boosting rounds
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration); DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration);
/*! /*!
* \brief Get number of eval * \brief Get number of eval
* \param out_len total number of eval results * \param out_len total number of eval results
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len);
/*! /*!
* \brief Get Name of eval * \brief Get Name of eval
...@@ -369,12 +369,12 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); ...@@ -369,12 +369,12 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len);
* \param out_strs names of eval result, need to pre-allocate memory before call this * \param out_strs names of eval result, need to pre-allocate memory before call this
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs); DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs);
/*! /*!
* \brief get evaluation for training data and validation data * \brief get evaluation for training data and validation data
Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evaluation results Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evaluation results
2. should pre-allocate memory for out_results, you can get its length by LGBM_BoosterGetEvalCounts 2. should pre-allocate memory for out_results, you can get its length by LGBM_BoosterGetEvalCounts
* \param handle handle * \param handle handle
* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result * \param out_len len of output result
...@@ -383,7 +383,7 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c ...@@ -383,7 +383,7 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c
*/ */
DllExport int LGBM_BoosterGetEval(BoosterHandle handle, DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len, int* out_len,
double* out_results); double* out_results);
/*! /*!
...@@ -401,8 +401,8 @@ DllExport int LGBM_BoosterGetNumPredict(BoosterHandle handle, ...@@ -401,8 +401,8 @@ DllExport int LGBM_BoosterGetNumPredict(BoosterHandle handle,
/*! /*!
* \brief Get prediction for training data and validation data * \brief Get prediction for training data and validation data
this can be used to support customized eval function this can be used to support customized eval function
Note: should pre-allocate memory for out_result, its length is equal to num_class * num_data Note: should pre-allocate memory for out_result, its length is equal to num_class * num_data
* \param handle handle * \param handle handle
* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result * \param out_len len of output result
...@@ -431,13 +431,13 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -431,13 +431,13 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
const char* result_filename); const char* result_filename);
/*! /*!
* \brief Get number of prediction * \brief Get number of prediction
* \param handle handle * \param handle handle
* \param num_row * \param num_row
* \param predict_type * \param predict_type
* C_API_PREDICT_NORMAL: normal prediction, with transform (if needed) * C_API_PREDICT_NORMAL: normal prediction, with transform (if needed)
* C_API_PREDICT_RAW_SCORE: raw score * C_API_PREDICT_RAW_SCORE: raw score
...@@ -447,14 +447,14 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -447,14 +447,14 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle, DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int64_t num_row, int num_row,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len); int64_t* out_len);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
* Note: should pre-allocate memory for out_result, * Note: should pre-allocate memory for out_result,
* for noraml and raw score: its length is equal to num_class * num_data * for noraml and raw score: its length is equal to num_class * num_data
* for leaf index, its length is equal to num_class * num_data * num_iteration * for leaf index, its length is equal to num_class * num_data * num_iteration
* \param handle handle * \param handle handle
...@@ -485,7 +485,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -485,7 +485,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -522,7 +522,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -522,7 +522,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -553,7 +553,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -553,7 +553,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -580,11 +580,11 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -580,11 +580,11 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle, DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int64_t* out_len, int* out_len,
char* out_str); char* out_str);
/*! /*!
* \brief Get leaf value * \brief Get leaf value
* \param handle handle * \param handle handle
* \param tree_idx index of tree * \param tree_idx index of tree
* \param leaf_idx index of leaf * \param leaf_idx index of leaf
......
...@@ -336,9 +336,9 @@ public: ...@@ -336,9 +336,9 @@ public:
bool SetIntField(const char* field_name, const int* field_data, data_size_t num_element); bool SetIntField(const char* field_name, const int* field_data, data_size_t num_element);
bool GetFloatField(const char* field_name, int64_t* out_len, const float** out_ptr); bool GetFloatField(const char* field_name, data_size_t* out_len, const float** out_ptr);
bool GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr); bool GetIntField(const char* field_name, data_size_t* out_len, const int** out_ptr);
/*! /*!
* \brief Save current dataset into binary file, will save to "filename.bin" * \brief Save current dataset into binary file, will save to "filename.bin"
......
...@@ -248,12 +248,12 @@ class _InnerPredictor(object): ...@@ -248,12 +248,12 @@ class _InnerPredictor(object):
self.__is_manage_handle = True self.__is_manage_handle = True
if model_file is not None: if model_file is not None:
"""Prediction task""" """Prediction task"""
out_num_iterations = ctypes.c_int64(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile( _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file), c_str(model_file),
ctypes.byref(out_num_iterations), ctypes.byref(out_num_iterations),
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
out_num_class = ctypes.c_int64(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
...@@ -262,12 +262,12 @@ class _InnerPredictor(object): ...@@ -262,12 +262,12 @@ class _InnerPredictor(object):
elif booster_handle is not None: elif booster_handle is not None:
self.__is_manage_handle = False self.__is_manage_handle = False
self.handle = booster_handle self.handle = booster_handle
out_num_class = ctypes.c_int64(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.num_class = out_num_class.value self.num_class = out_num_class.value
out_num_iterations = ctypes.c_int64(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration( _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle, self.handle,
ctypes.byref(out_num_iterations))) ctypes.byref(out_num_iterations)))
...@@ -320,9 +320,9 @@ class _InnerPredictor(object): ...@@ -320,9 +320,9 @@ class _InnerPredictor(object):
_safe_call(_LIB.LGBM_BoosterPredictForFile( _safe_call(_LIB.LGBM_BoosterPredictForFile(
self.handle, self.handle,
c_str(data), c_str(data),
int_data_has_header, ctypes.c_int(int_data_has_header),
predict_type, ctypes.c_int(predict_type),
num_iteration, ctypes.c_int(num_iteration),
c_str(f.name))) c_str(f.name)))
lines = f.readlines() lines = f.readlines()
nrow = len(lines) nrow = len(lines)
...@@ -364,9 +364,9 @@ class _InnerPredictor(object): ...@@ -364,9 +364,9 @@ class _InnerPredictor(object):
n_preds = ctypes.c_int64(0) n_preds = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterCalcNumPredict( _safe_call(_LIB.LGBM_BoosterCalcNumPredict(
self.handle, self.handle,
nrow, ctypes.c_int(nrow),
predict_type, ctypes.c_int(predict_type),
num_iteration, ctypes.c_int(num_iteration),
ctypes.byref(n_preds))) ctypes.byref(n_preds)))
return n_preds.value return n_preds.value
...@@ -390,12 +390,12 @@ class _InnerPredictor(object): ...@@ -390,12 +390,12 @@ class _InnerPredictor(object):
_safe_call(_LIB.LGBM_BoosterPredictForMat( _safe_call(_LIB.LGBM_BoosterPredictForMat(
self.handle, self.handle,
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
mat.shape[0], ctypes.c_int(mat.shape[0]),
mat.shape[1], ctypes.c_int(mat.shape[1]),
C_API_IS_ROW_MAJOR, ctypes.c_int(C_API_IS_ROW_MAJOR),
predict_type, ctypes.c_int(predict_type),
num_iteration, ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -417,15 +417,15 @@ class _InnerPredictor(object): ...@@ -417,15 +417,15 @@ class _InnerPredictor(object):
_safe_call(_LIB.LGBM_BoosterPredictForCSR( _safe_call(_LIB.LGBM_BoosterPredictForCSR(
self.handle, self.handle,
ptr_indptr, ptr_indptr,
type_ptr_indptr, ctypes.c_int32(type_ptr_indptr),
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
len(csr.indptr), ctypes.c_int64(len(csr.indptr)),
len(csr.data), ctypes.c_int64(len(csr.data)),
csr.shape[1], ctypes.c_int64(csr.shape[1]),
predict_type, ctypes.c_int(predict_type),
num_iteration, ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -447,15 +447,15 @@ class _InnerPredictor(object): ...@@ -447,15 +447,15 @@ class _InnerPredictor(object):
_safe_call(_LIB.LGBM_BoosterPredictForCSC( _safe_call(_LIB.LGBM_BoosterPredictForCSC(
self.handle, self.handle,
ptr_indptr, ptr_indptr,
type_ptr_indptr, ctypes.c_int32(type_ptr_indptr),
csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
len(csc.indptr), ctypes.c_int64(len(csc.indptr)),
len(csc.data), ctypes.c_int64(len(csc.data)),
csc.shape[0], ctypes.c_int64(csc.shape[0]),
predict_type, ctypes.c_int(predict_type),
num_iteration, ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -660,10 +660,10 @@ class Dataset(object): ...@@ -660,10 +660,10 @@ class Dataset(object):
ptr_data, type_ptr_data = c_float_array(data) ptr_data, type_ptr_data = c_float_array(data)
_safe_call(_LIB.LGBM_DatasetCreateFromMat( _safe_call(_LIB.LGBM_DatasetCreateFromMat(
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
mat.shape[0], ctypes.c_int(mat.shape[0]),
mat.shape[1], ctypes.c_int(mat.shape[1]),
C_API_IS_ROW_MAJOR, ctypes.c_int(C_API_IS_ROW_MAJOR),
c_str(params_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
...@@ -681,13 +681,13 @@ class Dataset(object): ...@@ -681,13 +681,13 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetCreateFromCSR( _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
ptr_indptr, ptr_indptr,
type_ptr_indptr, ctypes.c_int(type_ptr_indptr),
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
len(csr.indptr), ctypes.c_int64(len(csr.indptr)),
len(csr.data), ctypes.c_int64(len(csr.data)),
csr.shape[1], ctypes.c_int64(csr.shape[1]),
c_str(params_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
...@@ -705,13 +705,13 @@ class Dataset(object): ...@@ -705,13 +705,13 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetCreateFromCSC( _safe_call(_LIB.LGBM_DatasetCreateFromCSC(
ptr_indptr, ptr_indptr,
type_ptr_indptr, ctypes.c_int(type_ptr_indptr),
csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data, ptr_data,
type_ptr_data, ctypes.c_int(type_ptr_data),
len(csc.indptr), ctypes.c_int64(len(csc.indptr)),
len(csc.data), ctypes.c_int64(len(csc.data)),
csc.shape[0], ctypes.c_int64(csc.shape[0]),
c_str(params_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
...@@ -734,7 +734,7 @@ class Dataset(object): ...@@ -734,7 +734,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetGetSubset( _safe_call(_LIB.LGBM_DatasetGetSubset(
handle, handle,
used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
used_indices.shape[0], ctypes.c_int(used_indices.shape[0]),
c_str(params_str), c_str(params_str),
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
if self.get_label() is None: if self.get_label() is None:
...@@ -830,8 +830,8 @@ class Dataset(object): ...@@ -830,8 +830,8 @@ class Dataset(object):
self.handle, self.handle,
c_str(field_name), c_str(field_name),
None, None,
0, ctypes.c_int(0),
FIELD_TYPE_MAPPER[field_name])) ctypes.c_int(FIELD_TYPE_MAPPER[field_name])))
return return
dtype = np.int32 if field_name == 'group' else np.float32 dtype = np.int32 if field_name == 'group' else np.float32
data = list_to_1d_numpy(data, dtype, name=field_name) data = list_to_1d_numpy(data, dtype, name=field_name)
...@@ -849,8 +849,8 @@ class Dataset(object): ...@@ -849,8 +849,8 @@ class Dataset(object):
self.handle, self.handle,
c_str(field_name), c_str(field_name),
ptr_data, ptr_data,
len(data), ctypes.c_int(len(data)),
type_data)) ctypes.c_int(type_data)))
def get_field(self, field_name): def get_field(self, field_name):
"""Get property from the Dataset. """Get property from the Dataset.
...@@ -865,8 +865,8 @@ class Dataset(object): ...@@ -865,8 +865,8 @@ class Dataset(object):
info : array info : array
A numpy array of information of the data A numpy array of information of the data
""" """
tmp_out_len = ctypes.c_int64() tmp_out_len = ctypes.c_int()
out_type = ctypes.c_int32() out_type = ctypes.c_int()
ret = ctypes.POINTER(ctypes.c_void_p)() ret = ctypes.POINTER(ctypes.c_void_p)()
_safe_call(_LIB.LGBM_DatasetGetField( _safe_call(_LIB.LGBM_DatasetGetField(
self.handle, self.handle,
...@@ -955,7 +955,7 @@ class Dataset(object): ...@@ -955,7 +955,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetSetFeatureNames( _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
self.handle, self.handle,
c_array(ctypes.c_char_p, c_feature_name), c_array(ctypes.c_char_p, c_feature_name),
len(feature_name))) ctypes.c_int(len(feature_name))))
def set_label(self, label): def set_label(self, label):
""" """
...@@ -1076,7 +1076,7 @@ class Dataset(object): ...@@ -1076,7 +1076,7 @@ class Dataset(object):
number of rows : int number of rows : int
""" """
if self._is_constructed: if self._is_constructed:
ret = ctypes.c_int64() ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumData(self.handle, _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
...@@ -1092,7 +1092,7 @@ class Dataset(object): ...@@ -1092,7 +1092,7 @@ class Dataset(object):
number of columns : int number of columns : int
""" """
if self._is_constructed: if self._is_constructed:
ret = ctypes.c_int64() ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle, _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
...@@ -1147,7 +1147,7 @@ class Booster(object): ...@@ -1147,7 +1147,7 @@ class Booster(object):
_safe_call(_LIB.LGBM_BoosterMerge( _safe_call(_LIB.LGBM_BoosterMerge(
self.handle, self.handle,
self.__init_predictor.handle)) self.__init_predictor.handle))
out_num_class = ctypes.c_int64(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
...@@ -1158,12 +1158,12 @@ class Booster(object): ...@@ -1158,12 +1158,12 @@ class Booster(object):
self.__get_eval_info() self.__get_eval_info()
elif model_file is not None: elif model_file is not None:
"""Prediction task""" """Prediction task"""
out_num_iterations = ctypes.c_int64(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile( _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file), c_str(model_file),
ctypes.byref(out_num_iterations), ctypes.byref(out_num_iterations),
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
out_num_class = ctypes.c_int64(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
...@@ -1198,7 +1198,7 @@ class Booster(object): ...@@ -1198,7 +1198,7 @@ class Booster(object):
model = state['handle'] model = state['handle']
if model is not None: if model is not None:
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int64(0) out_num_iterations = ctypes.c_int(0)
with _temp_file() as f: with _temp_file() as f:
f.writelines(model) f.writelines(model)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile( _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
...@@ -1335,7 +1335,7 @@ class Booster(object): ...@@ -1335,7 +1335,7 @@ class Booster(object):
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
def current_iteration(self): def current_iteration(self):
out_cur_iter = ctypes.c_int64(0) out_cur_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration( _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle, self.handle,
ctypes.byref(out_cur_iter))) ctypes.byref(out_cur_iter)))
...@@ -1422,7 +1422,7 @@ class Booster(object): ...@@ -1422,7 +1422,7 @@ class Booster(object):
num_iteration = self.best_iteration num_iteration = self.best_iteration
_safe_call(_LIB.LGBM_BoosterSaveModel( _safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle, self.handle,
num_iteration, ctypes.c_int(num_iteration),
c_str(filename))) c_str(filename)))
def dump_model(self, num_iteration=-1): def dump_model(self, num_iteration=-1):
...@@ -1441,13 +1441,13 @@ class Booster(object): ...@@ -1441,13 +1441,13 @@ class Booster(object):
if num_iteration <= 0: if num_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel( _safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle, self.handle,
num_iteration, ctypes.c_int(num_iteration),
buffer_len, ctypes.c_int(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
actual_len = tmp_out_len.value actual_len = tmp_out_len.value
...@@ -1457,8 +1457,8 @@ class Booster(object): ...@@ -1457,8 +1457,8 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel( _safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle, self.handle,
num_iteration, ctypes.c_int(num_iteration),
actual_len, ctypes.c_int(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
return json.loads(string_buffer.value.decode()) return json.loads(string_buffer.value.decode())
...@@ -1533,10 +1533,10 @@ class Booster(object): ...@@ -1533,10 +1533,10 @@ class Booster(object):
ret = [] ret = []
if self.__num_inner_eval > 0: if self.__num_inner_eval > 0:
result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float64) result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float64)
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetEval( _safe_call(_LIB.LGBM_BoosterGetEval(
self.handle, self.handle,
data_idx, ctypes.c_int(data_idx),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if tmp_out_len.value != self.__num_inner_eval: if tmp_out_len.value != self.__num_inner_eval:
...@@ -1576,7 +1576,7 @@ class Booster(object): ...@@ -1576,7 +1576,7 @@ class Booster(object):
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))
_safe_call(_LIB.LGBM_BoosterGetPredict( _safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle, self.handle,
data_idx, ctypes.c_int(data_idx),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
data_ptr)) data_ptr))
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
...@@ -1590,7 +1590,7 @@ class Booster(object): ...@@ -1590,7 +1590,7 @@ class Booster(object):
""" """
if self.__need_reload_eval_info: if self.__need_reload_eval_info:
self.__need_reload_eval_info = False self.__need_reload_eval_info = False
out_num_eval = ctypes.c_int64(0) out_num_eval = ctypes.c_int(0)
"""Get num of inner evals""" """Get num of inner evals"""
_safe_call(_LIB.LGBM_BoosterGetEvalCounts( _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
self.handle, self.handle,
...@@ -1598,7 +1598,7 @@ class Booster(object): ...@@ -1598,7 +1598,7 @@ class Booster(object):
self.__num_inner_eval = out_num_eval.value self.__num_inner_eval = out_num_eval.value
if self.__num_inner_eval > 0: if self.__num_inner_eval > 0:
"""Get name of evals""" """Get name of evals"""
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int(0)
string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)] string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames( _safe_call(_LIB.LGBM_BoosterGetEvalNames(
......
...@@ -485,11 +485,11 @@ DllExport int LGBM_DatasetGetSubset( ...@@ -485,11 +485,11 @@ DllExport int LGBM_DatasetGetSubset(
DllExport int LGBM_DatasetSetFeatureNames( DllExport int LGBM_DatasetSetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
const char** feature_names, const char** feature_names,
int64_t num_feature_names) { int num_feature_names) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
std::vector<std::string> feature_names_str; std::vector<std::string> feature_names_str;
for (int64_t i = 0; i < num_feature_names; ++i) { for (int i = 0; i < num_feature_names; ++i) {
feature_names_str.emplace_back(feature_names[i]); feature_names_str.emplace_back(feature_names[i]);
} }
dataset->set_feature_names(feature_names_str); dataset->set_feature_names(feature_names_str);
...@@ -499,12 +499,12 @@ DllExport int LGBM_DatasetSetFeatureNames( ...@@ -499,12 +499,12 @@ DllExport int LGBM_DatasetSetFeatureNames(
DllExport int LGBM_DatasetGetFeatureNames( DllExport int LGBM_DatasetGetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
char** feature_names, char** feature_names,
int64_t* num_feature_names) { int* num_feature_names) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
auto inside_feature_name = dataset->feature_names(); auto inside_feature_name = dataset->feature_names();
*num_feature_names = static_cast<int64_t>(inside_feature_name.size()); *num_feature_names = static_cast<int>(inside_feature_name.size());
for (int64_t i = 0; i < *num_feature_names; ++i) { for (int i = 0; i < *num_feature_names; ++i) {
std::strcpy(feature_names[i], inside_feature_name[i].c_str()); std::strcpy(feature_names[i], inside_feature_name[i].c_str());
} }
API_END(); API_END();
...@@ -527,7 +527,7 @@ DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle, ...@@ -527,7 +527,7 @@ DllExport int LGBM_DatasetSaveBinary(DatasetHandle handle,
DllExport int LGBM_DatasetSetField(DatasetHandle handle, DllExport int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
int64_t num_element, int num_element,
int type) { int type) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
...@@ -543,7 +543,7 @@ DllExport int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -543,7 +543,7 @@ DllExport int LGBM_DatasetSetField(DatasetHandle handle,
DllExport int LGBM_DatasetGetField(DatasetHandle handle, DllExport int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int64_t* out_len, int* out_len,
const void** out_ptr, const void** out_ptr,
int* out_type) { int* out_type) {
API_BEGIN(); API_BEGIN();
...@@ -562,7 +562,7 @@ DllExport int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -562,7 +562,7 @@ DllExport int LGBM_DatasetGetField(DatasetHandle handle,
} }
DllExport int LGBM_DatasetGetNumData(DatasetHandle handle, DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
int64_t* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data(); *out = dataset->num_data();
...@@ -570,7 +570,7 @@ DllExport int LGBM_DatasetGetNumData(DatasetHandle handle, ...@@ -570,7 +570,7 @@ DllExport int LGBM_DatasetGetNumData(DatasetHandle handle,
} }
DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle, DllExport int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int64_t* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features(); *out = dataset->num_total_features();
...@@ -591,11 +591,11 @@ DllExport int LGBM_BoosterCreate(const DatasetHandle train_data, ...@@ -591,11 +591,11 @@ DllExport int LGBM_BoosterCreate(const DatasetHandle train_data,
DllExport int LGBM_BoosterCreateFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
int64_t* out_num_iterations, int* out_num_iterations,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
auto ret = std::unique_ptr<Booster>(new Booster(filename)); auto ret = std::unique_ptr<Booster>(new Booster(filename));
*out_num_iterations = static_cast<int64_t>(ret->GetBoosting()->GetCurrentIteration()); *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
*out = ret.release(); *out = ret.release();
API_END(); API_END();
} }
...@@ -640,7 +640,7 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param ...@@ -640,7 +640,7 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) { DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetBoosting()->NumberOfClasses(); *out_len = ref_booster->GetBoosting()->NumberOfClasses();
...@@ -679,21 +679,21 @@ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) { ...@@ -679,21 +679,21 @@ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) { DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_iteration = ref_booster->GetBoosting()->GetCurrentIteration(); *out_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) { DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalCounts(); *out_len = ref_booster->GetEvalCounts();
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) { DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs); *out_len = ref_booster->GetEvalNames(out_strs);
...@@ -702,13 +702,13 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c ...@@ -702,13 +702,13 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c
DllExport int LGBM_BoosterGetEval(BoosterHandle handle, DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len, int* out_len,
double* out_results) { double* out_results) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data_idx); auto result_buf = boosting->GetEvalAt(data_idx);
*out_len = static_cast<int64_t>(result_buf.size()); *out_len = static_cast<int>(result_buf.size());
for (size_t i = 0; i < result_buf.size(); ++i) { for (size_t i = 0; i < result_buf.size(); ++i) {
(out_results)[i] = static_cast<double>(result_buf[i]); (out_results)[i] = static_cast<double>(result_buf[i]);
} }
...@@ -738,7 +738,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -738,7 +738,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -762,9 +762,9 @@ int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t n ...@@ -762,9 +762,9 @@ int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t n
} }
DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle, DllExport int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int64_t num_row, int num_row,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -782,7 +782,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -782,7 +782,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t, int64_t,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -813,7 +813,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -813,7 +813,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -855,7 +855,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -855,7 +855,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int64_t num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -887,12 +887,12 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -887,12 +887,12 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle, DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int64_t* out_len, int* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(num_iteration); std::string model = ref_booster->DumpModel(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str()); std::strcpy(out_str, model.c_str());
} }
......
...@@ -98,7 +98,7 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si ...@@ -98,7 +98,7 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
return true; return true;
} }
bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const float** out_ptr) { bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const float** out_ptr) {
std::string name(field_name); std::string name(field_name);
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) { if (name == std::string("label") || name == std::string("target")) {
...@@ -116,7 +116,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa ...@@ -116,7 +116,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa
return true; return true;
} }
bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) { bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, const int** out_ptr) {
std::string name(field_name); std::string name(field_name);
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) { if (name == std::string("query") || name == std::string("group")) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment