"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "aa774f38f88ceedb3eb6ee19f94d94c8b1199bf9"
Commit 1c1749db authored by Guolin Ke
Browse files

fix bug in filter bin

parent 8a0f07ae
...@@ -425,8 +425,8 @@ inline static double ApproximateHessianWithGaussian(const double y, const double ...@@ -425,8 +425,8 @@ inline static double ApproximateHessianWithGaussian(const double y, const double
} }
template <typename T> template <typename T>
inline static T** Vector2Ptr(std::vector<std::vector<T>>& data) { inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
T** ptr = new T*[data.size()]; std::vector<T*> ptr(data.size());
for (size_t i = 0; i < data.size(); ++i) { for (size_t i = 0; i < data.size(); ++i) {
ptr[i] = data[i].data(); ptr[i] = data[i].data();
} }
......
...@@ -423,8 +423,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -423,8 +423,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
} }
} }
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt, nrow));
...@@ -487,8 +487,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -487,8 +487,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
} }
CHECK(num_col >= static_cast<int>(sample_values.size())); CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt, nrow));
...@@ -546,8 +546,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -546,8 +546,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
} }
} }
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt, nrow));
......
...@@ -49,18 +49,14 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin ...@@ -49,18 +49,14 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin
int sum_left = 0; int sum_left = 0;
for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) { for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) {
sum_left += cnt_in_bin[i]; sum_left += cnt_in_bin[i];
if (sum_left >= filter_cnt) { if (sum_left >= filter_cnt && total_cnt - sum_left >= filter_cnt) {
return false;
} else if (total_cnt - sum_left >= filter_cnt) {
return false; return false;
} }
} }
} else { } else {
for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) { for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) {
int sum_left = cnt_in_bin[i]; int sum_left = cnt_in_bin[i];
if (sum_left >= filter_cnt) { if (sum_left >= filter_cnt && total_cnt - sum_left >= filter_cnt) {
return false;
} else if (total_cnt - sum_left >= filter_cnt) {
return false; return false;
} }
} }
......
...@@ -141,8 +141,8 @@ void OverallConfig::CheckParamConflict() { ...@@ -141,8 +141,8 @@ void OverallConfig::CheckParamConflict() {
bool objective_type_multiclass = (objective_type == std::string("multiclass")); bool objective_type_multiclass = (objective_type == std::string("multiclass"));
int num_class_check = boosting_config.num_class; int num_class_check = boosting_config.num_class;
if (objective_type_multiclass) { if (objective_type_multiclass) {
if (num_class_check <= 2) { if (num_class_check <= 1) {
Log::Fatal("Number of classes should be specified and greater than 2 for multiclass training"); Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
} }
} else { } else {
if (task_type == TaskType::kTrain && num_class_check != 1) { if (task_type == TaskType::kTrain && num_class_check != 1) {
......
...@@ -487,7 +487,9 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -487,7 +487,9 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_.push_back(str_buf.str()); feature_names_.push_back(str_buf.str());
} }
} }
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / num_data * num_col);
const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf * total_sample_size) / num_data);
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -701,7 +703,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -701,7 +703,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
dataset->feature_names_ = feature_names_; dataset->feature_names_ = feature_names_;
std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size()); std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size());
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / dataset->num_data_ * sample_values.size()); const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf* sample_values.size()) / dataset->num_data_);
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
...@@ -815,7 +818,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -815,7 +818,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
} }
sample_values.clear(); sample_values.clear();
dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices), dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices).data(),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), io_config_); Common::VectorSize<int>(sample_indices).data(), sample_data.size(), io_config_);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment