"src/vscode:/vscode.git/clone" did not exist on "34a643f67527b742ce9084de7616301a82ff890e"
Unverified Commit d0ea321c authored by Scott Votaw's avatar Scott Votaw Committed by GitHub
Browse files

Add streaming concurrency tests (#5437)

parent 9dae0e6d
......@@ -185,7 +185,7 @@ void test_stream_sparse(
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;
dataset = static_cast<Dataset*>(dataset_handle);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 1);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2);
break;
}
......@@ -198,6 +198,7 @@ void test_stream_sparse(
dataset = static_cast<Dataset*>(dataset_handle);
Log::Info("Streaming sparse dataset, %d rows sparse data with a batch size of %d", nrows, batch_count);
TestUtils::StreamSparseDataset(
dataset_handle,
nrows,
......@@ -213,7 +214,6 @@ void test_stream_sparse(
dataset->FinishLoad();
Log::Info("Streaming sparse dataset, %d rows sparse data with a batch size of %d", nrows, batch_count);
TestUtils::AssertMetadata(&dataset->metadata(),
labels,
weights,
......@@ -320,7 +320,7 @@ TEST(Stream, PushSparseRowsWithMetadata) {
TestUtils::CreateRandomSparseData(nrows, ncols, nclasses, sparse_percent, &indptr, &indices, &vals, &labels, &weights, &init_scores, &groups);
const std::vector<int32_t> batch_counts = { 1, nrows / 100, nrows / 10, nrows };
const std::vector<int8_t> creation_types = { 0, 1 };
const std::vector<int8_t> creation_types = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
for (size_t i = 0; i < creation_types.size(); ++i) { // from sampled data or reference
for (size_t j = 0; j < batch_counts.size(); ++j) {
......
......@@ -9,6 +9,8 @@
#include <gtest/gtest.h>
#include <string>
#include <thread>
#include <utility>
using LightGBM::Log;
using LightGBM::Random;
......@@ -233,14 +235,6 @@ namespace LightGBM {
weights_ptr = weights->data();
}
// Since init_scores are in a column format, but need to be pushed as rows, we have to extract each batch
std::vector<double> init_score_batch;
const double* init_scores_ptr = nullptr;
if (init_scores) {
init_score_batch.reserve(nclasses * batch_count);
init_scores_ptr = init_score_batch.data();
}
const int32_t* groups_ptr = nullptr;
if (groups) {
groups_ptr = groups->data();
......@@ -248,14 +242,82 @@ namespace LightGBM {
auto start_time = std::chrono::steady_clock::now();
for (int32_t i = 0; i < nrows; i += batch_count) {
// Use multiple threads to test concurrency
int thread_count = 2;
if (nrows == batch_count) {
thread_count = 1; // If pushing all rows in 1 batch, we cannot have multiple threads
}
std::vector<std::thread> threads;
threads.reserve(thread_count);
for (int32_t t = 0; t < thread_count; ++t) {
std::thread th(TestUtils::PushSparseBatch,
dataset_handle,
nrows,
nclasses,
batch_count,
indptr,
indptr_ptr,
indices_ptr,
values_ptr,
labels_ptr,
weights_ptr,
init_scores,
groups_ptr,
thread_count,
t);
threads.push_back(move(th));
}
for (auto& t : threads) t.join();
auto cur_time = std::chrono::steady_clock::now();
Log::Info(" Time: %d", cur_time - start_time);
}
/*!
* Pushes data from 1 thread into a Dataset based on thread_id and nrows.
* e.g. with 100 rows and 2 threads, thread 0 will push rows 0-49, and thread 1 will push rows 50-99.
* Note that rows are still pushed in microbatches within their range.
*/
void TestUtils::PushSparseBatch(DatasetHandle dataset_handle,
int32_t nrows,
int32_t nclasses,
int32_t batch_count,
const std::vector<int32_t>* indptr,
const int32_t* indptr_ptr,
const int32_t* indices_ptr,
const double* values_ptr,
const float* labels_ptr,
const float* weights_ptr,
const std::vector<double>* init_scores,
const int32_t* groups_ptr,
int32_t thread_count,
int32_t thread_id) {
int32_t threadChunkSize = nrows / thread_count;
int32_t startIndex = threadChunkSize * thread_id;
int32_t stopIndex = startIndex + threadChunkSize;
indptr_ptr += threadChunkSize * thread_id;
labels_ptr += threadChunkSize * thread_id;
if (weights_ptr) {
weights_ptr += threadChunkSize * thread_id;
}
if (groups_ptr) {
groups_ptr += threadChunkSize * thread_id;
}
for (int32_t i = startIndex; i < stopIndex; i += batch_count) {
// Since init_scores are in a column format, but need to be pushed as rows, we have to extract each batch
std::vector<double> init_score_batch;
const double* init_scores_ptr = nullptr;
if (init_scores) {
init_score_batch.reserve(nclasses * batch_count);
init_scores_ptr = CreateInitScoreBatch(&init_score_batch, i, nrows, nclasses, batch_count, init_scores);
}
int32_t nelem = indptr->at(i + batch_count - 1) - indptr->at(i);
result = LGBM_DatasetPushRowsByCSRWithMetadata(dataset_handle,
int result = LGBM_DatasetPushRowsByCSRWithMetadata(dataset_handle,
indptr_ptr,
2,
indices_ptr,
......@@ -268,7 +330,7 @@ namespace LightGBM {
weights_ptr,
init_scores_ptr,
groups_ptr,
0);
thread_id);
EXPECT_EQ(0, result) << "LGBM_DatasetPushRowsByCSRWithMetadata result code: " << result;
if (result != 0) {
FAIL() << "LGBM_DatasetPushRowsByCSRWithMetadata failed"; // This forces an immediate failure, which EXPECT_EQ does not
......@@ -283,11 +345,9 @@ namespace LightGBM {
groups_ptr += batch_count;
}
}
auto cur_time = std::chrono::steady_clock::now();
Log::Info(" Time: %d", cur_time - start_time);
}
void TestUtils::AssertMetadata(const Metadata* metadata,
const std::vector<float>* ref_labels,
const std::vector<float>* ref_weights,
......@@ -296,7 +356,7 @@ namespace LightGBM {
const float* labels = metadata->label();
auto nTotal = static_cast<int32_t>(ref_labels->size());
for (auto i = 0; i < nTotal; i++) {
EXPECT_EQ(ref_labels->at(i), labels[i]) << "Inserted data: " << ref_labels->at(i);
EXPECT_EQ(ref_labels->at(i), labels[i]) << "Inserted data: " << ref_labels->at(i) << " at " << i;
if (ref_labels->at(i) != labels[i]) {
FAIL() << "Mismatched labels"; // This forces an immediate failure, which EXPECT_EQ does not
}
......
......@@ -103,6 +103,22 @@ class TestUtils {
int32_t nclasses,
int32_t batch_count,
const std::vector<double>* original_init_scores);
private:
/*!
 * Pushes one thread's share of sparse (CSR) rows into a Dataset.
 * The share is derived from thread_id and thread_count: the chunk size is
 * nrows / thread_count and the thread starts at chunk size * thread_id.
 * Rows inside that range are pushed batch_count at a time.
 * Passed as the entry point of the std::thread objects created while streaming
 * a sparse dataset.
 * NOTE(review): nrows is presumably expected to divide evenly by thread_count
 * (and each chunk by batch_count) -- confirm against callers.
 */
static void PushSparseBatch(DatasetHandle dataset_handle,
int32_t nrows,
int32_t nclasses,
int32_t batch_count,
const std::vector<int32_t>* indptr,
const int32_t* indptr_ptr,
const int32_t* indices_ptr,
const double* values_ptr,
const float* labels_ptr,
const float* weights_ptr,
const std::vector<double>* init_scores,
const int32_t* groups_ptr,
int32_t thread_count,
int32_t thread_id);
};
} // namespace LightGBM
#endif // LIGHTGBM_TESTUTILS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment