Commit de114be5 authored by Guolin Ke's avatar Guolin Ke
Browse files

add nullptr check for get_field

parent a178b75b
......@@ -143,8 +143,13 @@ public:
* \brief Get weights, if not exists, will return nullptr
* \return Pointer of weights
*/
inline const float* weights()
const { return weights_.data(); }
inline const float* weights() const {
if (weights_.size() > 0) {
return weights_.data();
} else {
return nullptr;
}
}
/*!
* \brief Get data boundaries on queries, if not exists, will return nullptr
......@@ -153,8 +158,13 @@ public:
* is the data indices for query i.
* \return Pointer of data boundaries on queries
*/
inline const data_size_t* query_boundaries()
const { return query_boundaries_.data(); }
inline const data_size_t* query_boundaries() const {
if (query_boundaries_.size() > 0) {
return query_boundaries_.data();
} else {
return nullptr;
}
}
/*!
* \brief Get Number of queries
......@@ -166,13 +176,25 @@ public:
* \brief Get weights for queries, if not exists, will return nullptr
* \return Pointer of weights for queries
*/
inline const float* query_weights() const { return query_weights_.data(); }
inline const float* query_weights() const {
if (query_weights_.size() > 0) {
return query_weights_.data();
} else {
return nullptr;
}
}
/*!
* \brief Get initial scores, if not exists, will return nullptr
* \return Pointer of initial scores
*/
inline const float* init_score() const { return init_score_.data(); }
inline const float* init_score() const {
if (init_score_.size() > 0) {
return init_score_.data();
} else {
return nullptr;
}
}
/*! \brief Disable copy */
Metadata& operator=(const Metadata&) = delete;
......
......@@ -387,7 +387,7 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
*out_type = C_API_DTYPE_INT32;
is_success = true;
}
if (!is_success) { throw std::runtime_error("Field not found"); }
if (!is_success) { throw std::runtime_error("Field not found or not exist"); }
API_END();
}
......
......@@ -101,7 +101,11 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa
} else {
return false;
}
return true;
if (*out_ptr != nullptr) {
return true;
} else {
return false;
}
}
bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
......@@ -109,11 +113,15 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) {
*out_ptr = metadata_.query_boundaries();
*out_len = num_data_;
*out_len = metadata_.num_queries();
} else {
return false;
}
if (*out_ptr != nullptr) {
return true;
} else {
return false;
}
return true;
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment