Commit 8c235f67 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code

parent 7b8bb4f2
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
virtual std::vector<double> GetEvalAt(int data_idx) const = 0; virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
virtual const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const = 0; virtual const score_t* GetTrainingScore(data_size_t* out_len) const = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
......
...@@ -39,8 +39,8 @@ DllExport const char* LGBM_GetLastError(); ...@@ -39,8 +39,8 @@ DllExport const char* LGBM_GetLastError();
/*! /*!
* \brief load data set from file like the command_line LightGBM do * \brief load data set from file like the command_line LightGBM do
* \param parameters additional parameters
* \param filename the name of the file * \param filename the name of the file
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out a loaded dataset * \param out a loaded dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -90,7 +90,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -90,7 +90,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
* \param indices findex * \param indices findex
* \param data fvalue * \param data fvalue
* \param float_type 0 for float_32 1 for float_64 * \param float_type 0 for float_32 1 for float_64
* \param nindptr number of rows in the matix + 1 * \param ncol_ptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix * \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data * \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters * \param parameters additional parameters
...@@ -116,7 +116,6 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr, ...@@ -116,7 +116,6 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
* \param nrow number of rows * \param nrow number of rows
* \param ncol number columns * \param ncol number columns
* \param is_row_major 1 for row major, 0 for column major * \param is_row_major 1 for row major, 0 for column major
* \param missing which value to represent missing value
* \param parameters additional parameters * \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset * \param out created dataset
...@@ -151,7 +150,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, ...@@ -151,7 +150,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param handle a instance of dataset * \param handle a instance of dataset
* \param field_name field name, can be label, weight, group * \param field_name field name, can be label, weight, group
* \param field_data pointer to vector * \param field_data pointer to vector
* \param field_len number of element in field_data * \param num_element number of element in field_data
* \param type float_32:0, uint32_t:1 * \param type float_32:0, uint32_t:1
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -216,6 +215,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -216,6 +215,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
/*! /*!
* \brief load an exsiting boosting from model file * \brief load an exsiting boosting from model file
* \param filename filename of model * \param filename filename of model
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterLoadFromModelfile( DllExport int LGBM_BoosterLoadFromModelfile(
...@@ -232,7 +232,7 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle); ...@@ -232,7 +232,7 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*! /*!
* \brief update the model in one round * \brief update the model in one round
* \param handle handle * \param handle handle
* \param is_finished 1 means finised * \param is_finished 1 means finised(cannot split any more)
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished); DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
...@@ -243,8 +243,7 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished); ...@@ -243,8 +243,7 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
* \param handle handle * \param handle handle
* \param grad gradient statistics * \param grad gradient statistics
* \param hess second order gradient statistics * \param hess second order gradient statistics
* \param float_type 0 for float_32, 1 for float_64 * \param is_finished 1 means finised(cannot split any more)
* \param is_finished 1 means finised
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
...@@ -256,6 +255,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -256,6 +255,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
* \brief get evaluation for training data and validation datas * \brief get evaluation for training data and validation datas
* \param handle handle * \param handle handle
* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param data 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result
* \param out_result the string containing evaluation statistics * \param out_result the string containing evaluation statistics
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -264,15 +264,27 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -264,15 +264,27 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
uint64_t* out_len, uint64_t* out_len,
double* out_results); double* out_results);
/*!
* \brief get raw score for training data, used to calculate gradients outside
* \param handle handle
* \param out_len len of output result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
uint64_t* out_len,
const float** out_result);
/*! /*!
* \brief make prediction for training data and validation datas * \brief make prediction for training data and validation datas
this can be used to support customized eval function / and gradients calculation this can be used to support customized eval function
* \param handle handle * \param handle handle
* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param data 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetScore(BoosterHandle handle, DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data, int data,
uint64_t* out_len, uint64_t* out_len,
const float** out_result); const float** out_result);
...@@ -283,12 +295,13 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle, ...@@ -283,12 +295,13 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
* \param indptr pointer to row headers * \param indptr pointer to row headers
* \param indices findex * \param indices findex
* \param data fvalue * \param data fvalue
* \param float_type 0:float_32 1:float64
* \param nindptr number of rows in the matix + 1 * \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix * \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data * \param num_col number of columns; when it's set to 0, then guess from data
* \param predict_type * \param predict_type
* 0:raw score * 0:raw score
* 1:with sigmoid transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param n_used_trees number of used tree
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
...@@ -312,12 +325,13 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -312,12 +325,13 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param col_ptr pointer to col headers * \param col_ptr pointer to col headers
* \param indices findex * \param indices findex
* \param data fvalue * \param data fvalue
* \param nindptr number of rows in the matix + 1 * \param float_type 0:float_32 1:float64
* \param ncol_ptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix * \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data * \param num_row number of rows; when it's set to 0, then guess from data
* \param predict_type * \param predict_type
* 0:raw score * 0:raw score
* 1:with sigmoid transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param n_used_trees number of used tree
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
...@@ -328,7 +342,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -328,7 +342,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nindptr, uint64_t ncol_ptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_row, uint64_t num_row,
int predict_type, int predict_type,
...@@ -339,12 +353,12 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -339,12 +353,12 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* \brief make prediction for an new data set * \brief make prediction for an new data set
* \param handle handle * \param handle handle
* \param data pointer to the data space * \param data pointer to the data space
* \param float_type 0:float_32 1:float64
* \param nrow number of rows * \param nrow number of rows
* \param ncol number columns * \param ncol number columns
* \param missing which value to represent missing value
* \param predict_type * \param predict_type
* 0:raw score * 0:raw score
* 1:with sigmoid transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param n_used_trees number of used tree
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
......
...@@ -294,16 +294,9 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const { ...@@ -294,16 +294,9 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
} }
/*! \brief Get prediction result */ /*! \brief Get prediction result */
const score_t* GBDT::GetScoreAt(int data_idx, data_size_t* out_len) const { const score_t* GBDT::GetTrainingScore(data_size_t* out_len) const {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
if (data_idx == 0) {
*out_len = train_score_updater_->num_data() * num_class_; *out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score(); return train_score_updater_->score();
} else {
auto used_idx = data_idx - 1;
*out_len = valid_score_updater_[used_idx]->num_data() * num_class_;
return valid_score_updater_[used_idx]->score();
}
} }
void GBDT::Boosting() { void GBDT::Boosting() {
......
...@@ -50,7 +50,7 @@ public: ...@@ -50,7 +50,7 @@ public:
std::vector<double> GetEvalAt(int data_idx) const override; std::vector<double> GetEvalAt(int data_idx) const override;
/*! \brief Get prediction result */ /*! \brief Get prediction result */
const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const override; const score_t* GetTrainingScore(data_size_t* out_len) const override;
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Predtion for one record without sigmoid transformation
......
...@@ -434,14 +434,13 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -434,14 +434,13 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
} }
DllExport int LGBM_BoosterGetScore(BoosterHandle handle, DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int data,
uint64_t* out_len, uint64_t* out_len,
const float** out_result) { const float** out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
int len = 0; int len = 0;
*out_result = boosting->GetScoreAt(data, &len); *out_result = boosting->GetTrainingScore(&len);
*out_len = static_cast<uint64_t>(len); *out_len = static_cast<uint64_t>(len);
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment