Unverified Commit d0ea321c authored by Scott Votaw's avatar Scott Votaw Committed by GitHub
Browse files

Add streaming concurrency tests (#5437)

parent 9dae0e6d
...@@ -185,7 +185,7 @@ void test_stream_sparse( ...@@ -185,7 +185,7 @@ void test_stream_sparse(
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result; EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;
dataset = static_cast<Dataset*>(dataset_handle); dataset = static_cast<Dataset*>(dataset_handle);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 1); dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2);
break; break;
} }
...@@ -198,6 +198,7 @@ void test_stream_sparse( ...@@ -198,6 +198,7 @@ void test_stream_sparse(
dataset = static_cast<Dataset*>(dataset_handle); dataset = static_cast<Dataset*>(dataset_handle);
Log::Info("Streaming sparse dataset, %d rows sparse data with a batch size of %d", nrows, batch_count);
TestUtils::StreamSparseDataset( TestUtils::StreamSparseDataset(
dataset_handle, dataset_handle,
nrows, nrows,
...@@ -213,7 +214,6 @@ void test_stream_sparse( ...@@ -213,7 +214,6 @@ void test_stream_sparse(
dataset->FinishLoad(); dataset->FinishLoad();
Log::Info("Streaming sparse dataset, %d rows sparse data with a batch size of %d", nrows, batch_count);
TestUtils::AssertMetadata(&dataset->metadata(), TestUtils::AssertMetadata(&dataset->metadata(),
labels, labels,
weights, weights,
...@@ -320,7 +320,7 @@ TEST(Stream, PushSparseRowsWithMetadata) { ...@@ -320,7 +320,7 @@ TEST(Stream, PushSparseRowsWithMetadata) {
TestUtils::CreateRandomSparseData(nrows, ncols, nclasses, sparse_percent, &indptr, &indices, &vals, &labels, &weights, &init_scores, &groups); TestUtils::CreateRandomSparseData(nrows, ncols, nclasses, sparse_percent, &indptr, &indices, &vals, &labels, &weights, &init_scores, &groups);
const std::vector<int32_t> batch_counts = { 1, nrows / 100, nrows / 10, nrows }; const std::vector<int32_t> batch_counts = { 1, nrows / 100, nrows / 10, nrows };
const std::vector<int8_t> creation_types = { 0, 1 }; const std::vector<int8_t> creation_types = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
for (size_t i = 0; i < creation_types.size(); ++i) { // from sampled data or reference for (size_t i = 0; i < creation_types.size(); ++i) { // from sampled data or reference
for (size_t j = 0; j < batch_counts.size(); ++j) { for (size_t j = 0; j < batch_counts.size(); ++j) {
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include <thread>
#include <utility>
using LightGBM::Log; using LightGBM::Log;
using LightGBM::Random; using LightGBM::Random;
...@@ -233,14 +235,6 @@ namespace LightGBM { ...@@ -233,14 +235,6 @@ namespace LightGBM {
weights_ptr = weights->data(); weights_ptr = weights->data();
} }
// Since init_scores are in a column format, but need to be pushed as rows, we have to extract each batch
std::vector<double> init_score_batch;
const double* init_scores_ptr = nullptr;
if (init_scores) {
init_score_batch.reserve(nclasses * batch_count);
init_scores_ptr = init_score_batch.data();
}
const int32_t* groups_ptr = nullptr; const int32_t* groups_ptr = nullptr;
if (groups) { if (groups) {
groups_ptr = groups->data(); groups_ptr = groups->data();
...@@ -248,14 +242,82 @@ namespace LightGBM { ...@@ -248,14 +242,82 @@ namespace LightGBM {
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
for (int32_t i = 0; i < nrows; i += batch_count) { // Use multiple threads to test concurrency
int thread_count = 2;
if (nrows == batch_count) {
thread_count = 1; // If pushing all rows in 1 batch, we cannot have multiple threads
}
std::vector<std::thread> threads;
threads.reserve(thread_count);
for (int32_t t = 0; t < thread_count; ++t) {
std::thread th(TestUtils::PushSparseBatch,
dataset_handle,
nrows,
nclasses,
batch_count,
indptr,
indptr_ptr,
indices_ptr,
values_ptr,
labels_ptr,
weights_ptr,
init_scores,
groups_ptr,
thread_count,
t);
threads.push_back(move(th));
}
for (auto& t : threads) t.join();
auto cur_time = std::chrono::steady_clock::now();
Log::Info(" Time: %d", cur_time - start_time);
}
/*!
* Pushes data from 1 thread into a Dataset based on thread_id and nrows.
* e.g. with 100 rows and 2 threads, thread 0 will push rows 0-49, and thread 1 will push rows 50-99.
* Note that rows are still pushed in microbatches within their range.
*/
void TestUtils::PushSparseBatch(DatasetHandle dataset_handle,
int32_t nrows,
int32_t nclasses,
int32_t batch_count,
const std::vector<int32_t>* indptr,
const int32_t* indptr_ptr,
const int32_t* indices_ptr,
const double* values_ptr,
const float* labels_ptr,
const float* weights_ptr,
const std::vector<double>* init_scores,
const int32_t* groups_ptr,
int32_t thread_count,
int32_t thread_id) {
int32_t threadChunkSize = nrows / thread_count;
int32_t startIndex = threadChunkSize * thread_id;
int32_t stopIndex = startIndex + threadChunkSize;
indptr_ptr += threadChunkSize * thread_id;
labels_ptr += threadChunkSize * thread_id;
if (weights_ptr) {
weights_ptr += threadChunkSize * thread_id;
}
if (groups_ptr) {
groups_ptr += threadChunkSize * thread_id;
}
for (int32_t i = startIndex; i < stopIndex; i += batch_count) {
// Since init_scores are in a column format, but need to be pushed as rows, we have to extract each batch
std::vector<double> init_score_batch;
const double* init_scores_ptr = nullptr;
if (init_scores) { if (init_scores) {
init_score_batch.reserve(nclasses * batch_count);
init_scores_ptr = CreateInitScoreBatch(&init_score_batch, i, nrows, nclasses, batch_count, init_scores); init_scores_ptr = CreateInitScoreBatch(&init_score_batch, i, nrows, nclasses, batch_count, init_scores);
} }
int32_t nelem = indptr->at(i + batch_count - 1) - indptr->at(i); int32_t nelem = indptr->at(i + batch_count - 1) - indptr->at(i);
result = LGBM_DatasetPushRowsByCSRWithMetadata(dataset_handle, int result = LGBM_DatasetPushRowsByCSRWithMetadata(dataset_handle,
indptr_ptr, indptr_ptr,
2, 2,
indices_ptr, indices_ptr,
...@@ -268,7 +330,7 @@ namespace LightGBM { ...@@ -268,7 +330,7 @@ namespace LightGBM {
weights_ptr, weights_ptr,
init_scores_ptr, init_scores_ptr,
groups_ptr, groups_ptr,
0); thread_id);
EXPECT_EQ(0, result) << "LGBM_DatasetPushRowsByCSRWithMetadata result code: " << result; EXPECT_EQ(0, result) << "LGBM_DatasetPushRowsByCSRWithMetadata result code: " << result;
if (result != 0) { if (result != 0) {
FAIL() << "LGBM_DatasetPushRowsByCSRWithMetadata failed"; // This forces an immediate failure, which EXPECT_EQ does not FAIL() << "LGBM_DatasetPushRowsByCSRWithMetadata failed"; // This forces an immediate failure, which EXPECT_EQ does not
...@@ -283,11 +345,9 @@ namespace LightGBM { ...@@ -283,11 +345,9 @@ namespace LightGBM {
groups_ptr += batch_count; groups_ptr += batch_count;
} }
} }
auto cur_time = std::chrono::steady_clock::now();
Log::Info(" Time: %d", cur_time - start_time);
} }
void TestUtils::AssertMetadata(const Metadata* metadata, void TestUtils::AssertMetadata(const Metadata* metadata,
const std::vector<float>* ref_labels, const std::vector<float>* ref_labels,
const std::vector<float>* ref_weights, const std::vector<float>* ref_weights,
...@@ -296,7 +356,7 @@ namespace LightGBM { ...@@ -296,7 +356,7 @@ namespace LightGBM {
const float* labels = metadata->label(); const float* labels = metadata->label();
auto nTotal = static_cast<int32_t>(ref_labels->size()); auto nTotal = static_cast<int32_t>(ref_labels->size());
for (auto i = 0; i < nTotal; i++) { for (auto i = 0; i < nTotal; i++) {
EXPECT_EQ(ref_labels->at(i), labels[i]) << "Inserted data: " << ref_labels->at(i); EXPECT_EQ(ref_labels->at(i), labels[i]) << "Inserted data: " << ref_labels->at(i) << " at " << i;
if (ref_labels->at(i) != labels[i]) { if (ref_labels->at(i) != labels[i]) {
FAIL() << "Mismatched labels"; // This forces an immediate failure, which EXPECT_EQ does not FAIL() << "Mismatched labels"; // This forces an immediate failure, which EXPECT_EQ does not
} }
......
...@@ -103,6 +103,22 @@ class TestUtils { ...@@ -103,6 +103,22 @@ class TestUtils {
int32_t nclasses, int32_t nclasses,
int32_t batch_count, int32_t batch_count,
const std::vector<double>* original_init_scores); const std::vector<double>* original_init_scores);
private:
/*!
* \brief Worker routine run on one thread by StreamSparseDataset: pushes a single
*        contiguous slice of the sparse rows into the dataset via
*        LGBM_DatasetPushRowsByCSRWithMetadata.
*        The slice begins at row (nrows / thread_count) * thread_id, is
*        nrows / thread_count rows long, and is pushed batch_count rows at a time.
*        NOTE(review): the chunk size uses integer division, so it looks like any
*        remainder rows (nrows % thread_count) are not pushed by any thread - confirm
*        callers only use row counts that divide evenly.
* \param dataset_handle Handle of the dataset that receives the rows
* \param nrows Total number of rows across all threads
* \param nclasses Number of classes; sizes the per-batch init score buffer
* \param batch_count Number of rows pushed per call
* \param indptr Row pointer vector, read to compute the element count of each batch
* \param indptr_ptr Pointer to the start of the row pointer data; offset internally to this thread's slice
* \param indices_ptr Pointer to the column index data (presumably CSR indices for all rows - verify against caller)
* \param values_ptr Pointer to the value data (presumably CSR values for all rows - verify against caller)
* \param labels_ptr Pointer to the first label; offset internally to this thread's slice
* \param weights_ptr Pointer to the first weight, or nullptr when there are no weights
* \param init_scores Init scores in column format, or nullptr; a row-format batch is extracted for each push
* \param groups_ptr Pointer to the first group id, or nullptr when there are no groups
* \param thread_count Number of threads the rows are divided among
* \param thread_id Zero-based index of this thread; also forwarded as the last argument of the push call
*/
static void PushSparseBatch(DatasetHandle dataset_handle,
int32_t nrows,
int32_t nclasses,
int32_t batch_count,
const std::vector<int32_t>* indptr,
const int32_t* indptr_ptr,
const int32_t* indices_ptr,
const double* values_ptr,
const float* labels_ptr,
const float* weights_ptr,
const std::vector<double>* init_scores,
const int32_t* groups_ptr,
int32_t thread_count,
int32_t thread_id);
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LIGHTGBM_TESTUTILS_H_ #endif // LIGHTGBM_TESTUTILS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment