Commit 9c5dbdde authored by Guolin Ke's avatar Guolin Ke
Browse files

[bug fix] fix predict sigmoid; fix bagging bug.

parent fd0cbe65
...@@ -129,11 +129,13 @@ public: ...@@ -129,11 +129,13 @@ public:
} }
/*! /*!
* \brief Construct feature value to bin mapper according to feature values * \brief Construct feature value to bin mapper according to feature values
* \param values (Sampled) values of this feature * \param column_name name of this column
* \param values (Sampled) values of this feature. Note: does not include zeros.
* \param total_sample_cnt total number of samples, equal to values.size() + num_zeros
* \param max_bin The maximal number of bin * \param max_bin The maximal number of bin
* \param bin_type Type of this bin * \param bin_type Type of this bin
*/ */
void FindBin(std::vector<double>* values, size_t total_sample_cnt, int max_bin, BinType bin_type); void FindBin(const std::string& column_name, std::vector<double>* values, size_t total_sample_cnt, int max_bin, BinType bin_type);
/*! /*!
* \brief Use specific number of bin to calculate the size of this class * \brief Use specific number of bin to calculate the size of this class
......
...@@ -25,7 +25,7 @@ GBDT::GBDT() ...@@ -25,7 +25,7 @@ GBDT::GBDT()
early_stopping_round_(0), early_stopping_round_(0),
max_feature_idx_(0), max_feature_idx_(0),
num_class_(1), num_class_(1),
sigmoid_(1.0f), sigmoid_(-1.0f),
num_iteration_for_pred_(0), num_iteration_for_pred_(0),
shrinkage_rate_(0.1f), shrinkage_rate_(0.1f),
num_init_iteration_(0) { num_init_iteration_(0) {
...@@ -187,6 +187,9 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -187,6 +187,9 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
} }
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer){ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer){
if (cnt <= 0) {
return 0;
}
data_size_t bag_data_cnt = data_size_t bag_data_cnt =
static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt); static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
data_size_t cur_left_cnt = 0; data_size_t cur_left_cnt = 0;
...@@ -492,7 +495,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -492,7 +495,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
} else if(sigmoid_ > 0.0f){ } else if(sigmoid_ > 0.0f){
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
out_result[i] = static_cast<double>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i]))); out_result[i] = static_cast<double>(1.0f / (1.0f + std::exp(- sigmoid_ * raw_scores[i])));
} }
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -761,7 +764,7 @@ std::vector<double> GBDT::Predict(const double* value) const { ...@@ -761,7 +764,7 @@ std::vector<double> GBDT::Predict(const double* value) const {
} }
// if need sigmoid transform // if need sigmoid transform
if (sigmoid_ > 0 && num_class_ == 1) { if (sigmoid_ > 0 && num_class_ == 1) {
ret[0] = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret[0])); ret[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * ret[0]));
} else if (num_class_ > 1) { } else if (num_class_ > 1) {
Common::Softmax(&ret); Common::Softmax(&ret);
} }
......
...@@ -41,7 +41,7 @@ BinMapper::~BinMapper() { ...@@ -41,7 +41,7 @@ BinMapper::~BinMapper() {
} }
void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, int max_bin, BinType bin_type) { void BinMapper::FindBin(const std::string& column_name, std::vector<double>* values, size_t total_sample_cnt, int max_bin, BinType bin_type) {
bin_type_ = bin_type; bin_type_ = bin_type;
std::vector<double>& ref_values = (*values); std::vector<double>& ref_values = (*values);
size_t sample_size = total_sample_cnt; size_t sample_size = total_sample_cnt;
...@@ -181,7 +181,7 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -181,7 +181,7 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
} }
if (used_cnt / static_cast<double>(sample_size) < 0.95f) { if (used_cnt / static_cast<double>(sample_size) < 0.95f) {
Log::Warning("Too many categoricals are ignored, \ Log::Warning("Too many categoricals are ignored, \
please use bigger max_bin or partition this column "); please use bigger max_bin or partition column \"%s\" ", column_name.c_str());
} }
cnt_in_bin0 = static_cast<int>(sample_size) - used_cnt + counts_int[0]; cnt_in_bin0 = static_cast<int>(sample_size) - used_cnt + counts_int[0];
} }
......
...@@ -433,6 +433,14 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -433,6 +433,14 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& sample_values, size_t total_sample_size, data_size_t num_data) { Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& sample_values, size_t total_sample_size, data_size_t num_data) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size()); std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size());
// fill feature_names_ if not header
if (feature_names_.empty()) {
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
std::stringstream str_buf;
str_buf << "Column_" << i;
feature_names_.push_back(str_buf.str());
}
}
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
...@@ -440,7 +448,7 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& ...@@ -440,7 +448,7 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
if (categorical_features_.count(i)) { if (categorical_features_.count(i)) {
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mappers[i]->FindBin(&sample_values[i], total_sample_size, io_config_.max_bin, bin_type); bin_mappers[i]->FindBin(feature_names_[i], &sample_values[i], total_sample_size, io_config_.max_bin, bin_type);
} }
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
...@@ -467,14 +475,6 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& ...@@ -467,14 +475,6 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
} }
} }
dataset->features_.shrink_to_fit(); dataset->features_.shrink_to_fit();
// fill feature_names_ if not header
if (feature_names_.empty()) {
for (int i = 0; i < dataset->num_total_features_; ++i) {
std::stringstream str_buf;
str_buf << "Column_" << i;
feature_names_.push_back(str_buf.str());
}
}
dataset->feature_names_ = feature_names_; dataset->feature_names_ = feature_names_;
dataset->num_features_ = static_cast<int>(dataset->features_.size()); dataset->num_features_ = static_cast<int>(dataset->features_.size());
dataset->metadata_.Init(dataset->num_data_, NO_SPECIFIC, NO_SPECIFIC); dataset->metadata_.Init(dataset->num_data_, NO_SPECIFIC, NO_SPECIFIC);
...@@ -668,7 +668,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -668,7 +668,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (categorical_features_.count(i)) { if (categorical_features_.count(i)) {
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mappers[i]->FindBin(&sample_values[i], sample_data.size(), io_config_.max_bin, bin_type); bin_mappers[i]->FindBin(feature_names_[i], &sample_values[i], sample_data.size(), io_config_.max_bin, bin_type);
} }
for (size_t i = 0; i < sample_values.size(); ++i) { for (size_t i = 0; i < sample_values.size(); ++i) {
...@@ -722,7 +722,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -722,7 +722,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (categorical_features_.count(start[rank] + i)) { if (categorical_features_.count(start[rank] + i)) {
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mapper.FindBin(&sample_values[start[rank] + i], sample_data.size(), io_config_.max_bin, bin_type); bin_mapper.FindBin(feature_names_[start[rank] + i], &sample_values[start[rank] + i], sample_data.size(), io_config_.max_bin, bin_type);
bin_mapper.CopyTo(input_buffer.data() + i * type_size); bin_mapper.CopyTo(input_buffer.data() + i * type_size);
} }
// convert to binary size // convert to binary size
......
...@@ -63,7 +63,7 @@ public: ...@@ -63,7 +63,7 @@ public:
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
double prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i])); double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i]));
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob); sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
} }
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
double prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i])); double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i]));
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i]; sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment