Commit 5b4ee9db authored by Guolin Ke's avatar Guolin Ke
Browse files

support set/get dataset field with nullptr

parent 4accb9d4
...@@ -419,7 +419,8 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, ...@@ -419,7 +419,8 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
*out_type = C_API_DTYPE_INT32; *out_type = C_API_DTYPE_INT32;
is_success = true; is_success = true;
} }
if (!is_success) { throw std::runtime_error("Field not found or not exist"); } if (!is_success) { throw std::runtime_error("Field not found"); }
if (*out_ptr == nullptr) { *out_len = 0; }
API_END(); API_END();
} }
......
...@@ -101,11 +101,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa ...@@ -101,11 +101,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa
} else { } else {
return false; return false;
} }
if (*out_ptr != nullptr) {
return true; return true;
} else {
return false;
}
} }
bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) { bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
...@@ -117,11 +113,7 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** ...@@ -117,11 +113,7 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
} else { } else {
return false; return false;
} }
if (*out_ptr != nullptr) {
return true; return true;
} else {
return false;
}
} }
void Dataset::SaveBinaryFile(const char* bin_filename) { void Dataset::SaveBinaryFile(const char* bin_filename) {
......
...@@ -196,6 +196,12 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -196,6 +196,12 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
void Metadata::SetInitScore(const float* init_score, data_size_t len) { void Metadata::SetInitScore(const float* init_score, data_size_t len) {
// save to nullptr
if (init_score == nullptr || len == 0) {
init_score_.clear();
num_init_score_ = 0;
return;
}
if (len != num_data_ * num_class_) { if (len != num_data_ * num_class_) {
Log::Fatal("Initial score size doesn't match data size"); Log::Fatal("Initial score size doesn't match data size");
} }
...@@ -208,6 +214,9 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) { ...@@ -208,6 +214,9 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) {
} }
void Metadata::SetLabel(const float* label, data_size_t len) { void Metadata::SetLabel(const float* label, data_size_t len) {
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
}
if (num_data_ != len) { if (num_data_ != len) {
Log::Fatal("len of label is not same with #data"); Log::Fatal("len of label is not same with #data");
} }
...@@ -219,6 +228,12 @@ void Metadata::SetLabel(const float* label, data_size_t len) { ...@@ -219,6 +228,12 @@ void Metadata::SetLabel(const float* label, data_size_t len) {
} }
void Metadata::SetWeights(const float* weights, data_size_t len) { void Metadata::SetWeights(const float* weights, data_size_t len) {
// save to nullptr
if (weights == nullptr || len == 0) {
weights_.clear();
num_weights_ = 0;
return;
}
if (num_data_ != len) { if (num_data_ != len) {
Log::Fatal("len of weights is not same with #data"); Log::Fatal("len of weights is not same with #data");
} }
...@@ -232,6 +247,12 @@ void Metadata::SetWeights(const float* weights, data_size_t len) { ...@@ -232,6 +247,12 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
} }
void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) { void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) {
// save to nullptr
if (query_boundaries == nullptr || len == 0) {
query_boundaries_.clear();
num_queries_ = 0;
return;
}
data_size_t sum = 0; data_size_t sum = 0;
for (data_size_t i = 0; i < len; ++i) { for (data_size_t i = 0; i < len; ++i) {
sum += query_boundaries[i]; sum += query_boundaries[i];
...@@ -249,6 +270,13 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size ...@@ -249,6 +270,13 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
} }
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) { void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
// save to nullptr
if (query_id == nullptr || len == 0) {
query_boundaries_.clear();
queries_.clear();
num_queries_ = 0;
return;
}
if (num_data_ != len) { if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data"); Log::Fatal("len of query id is not same with #data");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment