Commit 5b4ee9db authored by Guolin Ke's avatar Guolin Ke
Browse files

support set/get dataset field with nullptr

parent 4accb9d4
......@@ -419,7 +419,8 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
*out_type = C_API_DTYPE_INT32;
is_success = true;
}
if (!is_success) { throw std::runtime_error("Field not found or not exist"); }
if (!is_success) { throw std::runtime_error("Field not found"); }
if (*out_ptr == nullptr) { *out_len = 0; }
API_END();
}
......
......@@ -101,11 +101,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa
} else {
return false;
}
if (*out_ptr != nullptr) {
return true;
} else {
return false;
}
}
bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
......@@ -117,11 +113,7 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
} else {
return false;
}
if (*out_ptr != nullptr) {
return true;
} else {
return false;
}
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
......
......@@ -196,6 +196,12 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
void Metadata::SetInitScore(const float* init_score, data_size_t len) {
// save to nullptr
if (init_score == nullptr || len == 0) {
init_score_.clear();
num_init_score_ = 0;
return;
}
if (len != num_data_ * num_class_) {
Log::Fatal("Initial score size doesn't match data size");
}
......@@ -208,6 +214,9 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) {
}
void Metadata::SetLabel(const float* label, data_size_t len) {
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
}
if (num_data_ != len) {
Log::Fatal("len of label is not same with #data");
}
......@@ -219,6 +228,12 @@ void Metadata::SetLabel(const float* label, data_size_t len) {
}
void Metadata::SetWeights(const float* weights, data_size_t len) {
// save to nullptr
if (weights == nullptr || len == 0) {
weights_.clear();
num_weights_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("len of weights is not same with #data");
}
......@@ -232,6 +247,12 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
}
void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) {
// save to nullptr
if (query_boundaries == nullptr || len == 0) {
query_boundaries_.clear();
num_queries_ = 0;
return;
}
data_size_t sum = 0;
for (data_size_t i = 0; i < len; ++i) {
sum += query_boundaries[i];
......@@ -249,6 +270,13 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
}
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
// save to nullptr
if (query_id == nullptr || len == 0) {
query_boundaries_.clear();
queries_.clear();
num_queries_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment