Commit 8c235f67 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code

parent 7b8bb4f2
......@@ -48,7 +48,7 @@ public:
virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
virtual const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const = 0;
virtual const score_t* GetTrainingScore(data_size_t* out_len) const = 0;
/*!
* \brief Prediction for one record, not sigmoid transform
......
......@@ -39,8 +39,8 @@ DllExport const char* LGBM_GetLastError();
/*!
* \brief load data set from file like the command_line LightGBM do
* \param parameters additional parameters
* \param filename the name of the file
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out a loaded dataset
* \return 0 when success, -1 when failure happens
......@@ -90,7 +90,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
* \param indices findex
* \param data fvalue
* \param float_type 0 for float_32 1 for float_64
* \param nindptr number of rows in the matix + 1
* \param ncol_ptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
......@@ -116,7 +116,6 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
* \param nrow number of rows
* \param ncol number columns
* \param is_row_major 1 for row major, 0 for column major
* \param missing which value to represent missing value
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset
......@@ -151,7 +150,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param handle a instance of dataset
* \param field_name field name, can be label, weight, group
* \param field_data pointer to vector
* \param field_len number of element in field_data
* \param num_element number of element in field_data
* \param type float_32:0, uint32_t:1
* \return 0 when success, -1 when failure happens
*/
......@@ -216,6 +215,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
/*!
* \brief load an exsiting boosting from model file
* \param filename filename of model
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterLoadFromModelfile(
......@@ -232,7 +232,7 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief update the model in one round
* \param handle handle
* \param is_finished 1 means finised
* \param is_finished 1 means finised(cannot split any more)
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
......@@ -243,8 +243,7 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
* \param handle handle
* \param grad gradient statistics
* \param hess second order gradient statistics
* \param float_type 0 for float_32, 1 for float_64
* \param is_finished 1 means finised
* \param is_finished 1 means finised(cannot split any more)
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
......@@ -256,6 +255,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
* \brief get evaluation for training data and validation datas
* \param handle handle
* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result
* \param out_result the string containing evaluation statistics
* \return 0 when success, -1 when failure happens
*/
......@@ -264,15 +264,27 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
uint64_t* out_len,
double* out_results);
/*!
* \brief get raw score for training data, used to calculate gradients outside
* \param handle handle
* \param out_len len of output result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
uint64_t* out_len,
const float** out_result);
/*!
* \brief make prediction for training data and validation datas
this can be used to support customized eval function / and gradients calculation
this can be used to support customized eval function
* \param handle handle
* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_len len of output result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data,
uint64_t* out_len,
const float** out_result);
......@@ -283,12 +295,13 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param float_type 0:float_32 1:float64
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 1:with transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
......@@ -312,12 +325,13 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param float_type 0:float_32 1:float64
* \param ncol_ptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 1:with transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
......@@ -328,7 +342,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t ncol_ptr,
uint64_t nelem,
uint64_t num_row,
int predict_type,
......@@ -339,12 +353,12 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* \brief make prediction for an new data set
* \param handle handle
* \param data pointer to the data space
* \param float_type 0:float_32 1:float64
* \param nrow number of rows
* \param ncol number columns
* \param missing which value to represent missing value
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 1:with transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
......
......@@ -294,16 +294,9 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
}
/*! \brief Get prediction result */
const score_t* GBDT::GetScoreAt(int data_idx, data_size_t* out_len) const {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
if (data_idx == 0) {
const score_t* GBDT::GetTrainingScore(data_size_t* out_len) const {
*out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score();
} else {
auto used_idx = data_idx - 1;
*out_len = valid_score_updater_[used_idx]->num_data() * num_class_;
return valid_score_updater_[used_idx]->score();
}
}
void GBDT::Boosting() {
......
......@@ -50,7 +50,7 @@ public:
std::vector<double> GetEvalAt(int data_idx) const override;
/*! \brief Get prediction result */
const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const override;
const score_t* GetTrainingScore(data_size_t* out_len) const override;
/*!
* \brief Predtion for one record without sigmoid transformation
......
......@@ -434,14 +434,13 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
}
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int data,
uint64_t* out_len,
const float** out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0;
*out_result = boosting->GetScoreAt(data, &len);
*out_result = boosting->GetTrainingScore(&len);
*out_len = static_cast<uint64_t>(len);
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment