Commit 1bade1e2 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in data/voting parallel.

parent 4d6ff287
...@@ -22,10 +22,7 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) { ...@@ -22,10 +22,7 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
rank_ = Network::rank(); rank_ = Network::rank();
num_machines_ = Network::num_machines(); num_machines_ = Network::num_machines();
// allocate buffer for communication // allocate buffer for communication
size_t buffer_size = 0; size_t buffer_size = train_data_->NumTotalBin() * sizeof(HistogramBinEntry);
for (int i = 0; i < num_features_; ++i) {
buffer_size += train_data_->FeatureNumBin(i) * sizeof(HistogramBinEntry);
}
input_buffer_.resize(buffer_size); input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size);
...@@ -54,7 +51,11 @@ void DataParallelTreeLearner::BeforeTrain() { ...@@ -54,7 +51,11 @@ void DataParallelTreeLearner::BeforeTrain() {
if (is_feature_used_[i]) { if (is_feature_used_[i]) {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed)); int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(i); feature_distribution[cur_min_machine].push_back(i);
num_bins_distributed[cur_min_machine] += train_data_->FeatureNumBin(i); auto num_bin = train_data_->FeatureNumBin(i);
if (train_data_->FeatureBinMapper(i)->GetDefaultBin() == 0) {
num_bin -= 1;
}
num_bins_distributed[cur_min_machine] += num_bin;
} }
is_feature_aggregated_[i] = false; is_feature_aggregated_[i] = false;
} }
...@@ -68,7 +69,11 @@ void DataParallelTreeLearner::BeforeTrain() { ...@@ -68,7 +69,11 @@ void DataParallelTreeLearner::BeforeTrain() {
for (int i = 0; i < num_machines_; ++i) { for (int i = 0; i < num_machines_; ++i) {
block_len_[i] = 0; block_len_[i] = 0;
for (auto fid : feature_distribution[i]) { for (auto fid : feature_distribution[i]) {
block_len_[i] += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry); auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
block_len_[i] += num_bin * sizeof(HistogramBinEntry);
} }
reduce_scatter_size_ += block_len_[i]; reduce_scatter_size_ += block_len_[i];
} }
...@@ -83,7 +88,11 @@ void DataParallelTreeLearner::BeforeTrain() { ...@@ -83,7 +88,11 @@ void DataParallelTreeLearner::BeforeTrain() {
for (int i = 0; i < num_machines_; ++i) { for (int i = 0; i < num_machines_; ++i) {
for (auto fid : feature_distribution[i]) { for (auto fid : feature_distribution[i]) {
buffer_write_start_pos_[fid] = bin_size; buffer_write_start_pos_[fid] = bin_size;
bin_size += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry); auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
} }
} }
...@@ -91,7 +100,11 @@ void DataParallelTreeLearner::BeforeTrain() { ...@@ -91,7 +100,11 @@ void DataParallelTreeLearner::BeforeTrain() {
bin_size = 0; bin_size = 0;
for (auto fid : feature_distribution[rank_]) { for (auto fid : feature_distribution[rank_]) {
buffer_read_start_pos_[fid] = bin_size; buffer_read_start_pos_[fid] = bin_size;
bin_size += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry); auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
} }
// sync global data sumup info // sync global data sumup info
...@@ -157,7 +170,7 @@ void DataParallelTreeLearner::FindBestThresholds() { ...@@ -157,7 +170,7 @@ void DataParallelTreeLearner::FindBestThresholds() {
train_data_->FixHistogram(feature_index, train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(), smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(), GetGlobalDataCountInLeaf(smaller_leaf_splits_->LeafIndex()),
smaller_leaf_histogram_array_[feature_index].RawData()); smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split; SplitInfo smaller_split;
// find best threshold for smaller child // find best threshold for smaller child
......
...@@ -368,6 +368,12 @@ void VotingParallelTreeLearner::FindBestThresholds() { ...@@ -368,6 +368,12 @@ void VotingParallelTreeLearner::FindBestThresholds() {
// restore from buffer // restore from buffer
smaller_leaf_histogram_array_global_[feature_index].FromMemory( smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]); output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold // find best threshold
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold( smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_gradients(),
...@@ -383,6 +389,12 @@ void VotingParallelTreeLearner::FindBestThresholds() { ...@@ -383,6 +389,12 @@ void VotingParallelTreeLearner::FindBestThresholds() {
SplitInfo larger_split; SplitInfo larger_split;
// restore from buffer // restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]); larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold // find best threshold
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold( larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_gradients(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment