Commit 798670d8 authored by rocking's avatar rocking
Browse files

Fix bug of concurrency and add test case which may fail orignally

parent e48ddb6a
...@@ -297,9 +297,14 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -297,9 +297,14 @@ struct GridwiseLayernorm_mk_to_mk
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
......
...@@ -169,7 +169,7 @@ class TestLayernorm : public ::testing::Test ...@@ -169,7 +169,7 @@ class TestLayernorm : public ::testing::Test
} }
std::vector<std::vector<index_t>> lengths_ = { std::vector<std::vector<index_t>> lengths_ = {
{4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}}; {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};
std::vector<std::vector<index_t>> reduceDims_ = {{1}}; std::vector<std::vector<index_t>> reduceDims_ = {{1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment