Commit 1f91449d authored by Jakub Piasecki

first fixes for gemm

parent 824809c1
@@ -57,8 +57,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
 
+    printf("PrefetchStages: %d\n", BaseGemmPipeline::PrefetchStages);
+    printf("num_loop: %d\n", num_loop);
+    printf("has_hot_loop: %d\n", has_hot_loop);
+    printf("tail_num: %d\n", static_cast<int>(tail_num));
+
     float ave_time{0};
 
     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;
@@ -86,7 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
     constexpr dim3 blocks = Kernel::BlockSize();
 
-    if(s.log_level_ > 0)
+    if(true)
     {
         std::cout << "Launching kernel with args:"
                   << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
@@ -169,7 +174,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     {
         Run(ck_tile::bool_constant<false>{},
             ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
-    }
+    } // what if not?
 }
 
 return ave_time;
...
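A note on the Run lambda above: BlockHasHotloop and GetBlockLoopTailNum produce runtime values, and the if/else ladder lifts them into template parameters so each pipeline variant is compiled separately. A minimal standalone sketch of that dispatch pattern, with hypothetical names and std:: types standing in for the ck_tile ones:

#include <cstdio>
#include <type_traits>

// Stand-ins for the ck_tile types; names are hypothetical.
enum class TailNumber { One, Full };

template <bool HasHotLoop, TailNumber TailNum>
void run_gemm_kernel()
{
    printf("instantiated with has_hot_loop=%d tail_num=%d\n",
           HasHotLoop, static_cast<int>(TailNum));
}

void dispatch(bool has_hot_loop, TailNumber tail_num)
{
    // Mirrors the Run lambda: the integral_constant arguments carry the
    // values as types, so they can be read back as constexpr .value members.
    const auto run = [&](auto has_hot_loop_, auto tail_number_) {
        run_gemm_kernel<decltype(has_hot_loop_)::value, decltype(tail_number_)::value>();
    };

    if(has_hot_loop)
    {
        if(tail_num == TailNumber::One)
            run(std::bool_constant<true>{},
                std::integral_constant<TailNumber, TailNumber::One>{});
        else if(tail_num == TailNumber::Full)
            run(std::bool_constant<true>{},
                std::integral_constant<TailNumber, TailNumber::Full>{});
    }
    else
    {
        // Without a hot loop the tail number is always One (the branch the
        // "what if not?" comment above is asking about).
        run(std::bool_constant<false>{},
            std::integral_constant<TailNumber, TailNumber::One>{});
    }
}

int main() { dispatch(true, TailNumber::Full); }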
@@ -67,6 +67,7 @@ int run_gemm_example(int argc, char* argv[])
     int n_repeat = arg_parser.get_int("repeat");
 
     using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    //using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
     using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
     using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
...
@@ -166,7 +166,7 @@
 #endif
 #ifndef CK_TILE_DEBUG_LOG
-#define CK_TILE_DEBUG_LOG 0
+#define CK_TILE_DEBUG_LOG 1
 #endif
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
...
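CK_TILE_DEBUG_LOG is a compile-time switch, so flipping its default to 1 enables the library's debug output globally. A sketch of the usual consumption pattern for a guard like this (the DEBUG_LOG wrapper below is illustrative, not ck_tile's actual API):

#include <cstdio>

// CK_TILE_DEBUG_LOG as in the header above; DEBUG_LOG is a hypothetical
// consumer added for illustration.
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 1
#endif

#if CK_TILE_DEBUG_LOG
#define DEBUG_LOG(...) printf(__VA_ARGS__)
#else
#define DEBUG_LOG(...) ((void)0) // expands to nothing when logging is off
#endif

int main() { DEBUG_LOG("debug logging is %s\n", CK_TILE_DEBUG_LOG ? "on" : "off"); }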
@@ -138,6 +138,11 @@ struct BlockGemmASmemBSmemCRegV1
     constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
 
     // hot loop:
+    // if(threadIdx.x == 0) {
+    //     printf("block gemm\n");
+    // }
     static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
         static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
             // read A warp tensor from A block window

@@ -162,6 +167,12 @@ struct BlockGemmASmemBSmemCRegV1
                     merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                     merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                     c_warp_tensor.get_thread_buffer());
+                // if(threadIdx.x == 0) {
+                //     printf("C warp\n");
+                //     tile_elementwise_inout([](auto& c) { printf("%f ", static_cast<float>(c)); }, c_block_tensor);
+                //     printf("\n");
+                // }
             });
         });
     });
...
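The commented-out dumps above guard printf with threadIdx.x == 0, which still prints once per block. Guarding on the block index as well yields a single clean dump; a minimal HIP sketch (hypothetical kernel, for illustration only):

#include <hip/hip_runtime.h>
#include <cstdio>

// Restricting to one thread of one block keeps device-side printf output
// from interleaving across the grid.
__global__ void debug_dump(const float* buf, int n)
{
    if(threadIdx.x == 0 && blockIdx.x == 0)
    {
        for(int i = 0; i < n; ++i)
            printf("%f ", buf[i]);
        printf("\n");
    }
}

int main()
{
    float* buf = nullptr;
    (void)hipMalloc(&buf, 16 * sizeof(float));
    (void)hipMemset(buf, 0, 16 * sizeof(float));
    debug_dump<<<dim3(4), dim3(64)>>>(buf, 16); // prints once, not once per block
    (void)hipDeviceSynchronize();
    (void)hipFree(buf);
}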
@@ -276,10 +276,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
     // initialize C
     tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
 
+    if(threadIdx.x == 0) {
+        printf("gemm_pipeline_ag_bg_cr_mem\n");
+        printf("A in: ");
+        static_for<0, 16, 1>{}([&](auto i) {
+            printf("%f ", static_cast<float>(a_block_tiles.get(I0{}).get_thread_buffer()[i]));
+        });
+        printf("\nB in: ");
+        static_for<0, 16, 1>{}([&](auto i) {
+            printf("%f ", static_cast<float>(b_block_tiles.get(I0{}).get_thread_buffer()[i]));
+        });
+        printf("\n");
+    }
+
     // LDS write 0
     LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
     LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+    // print a_block_tiles, b_block_tiles
 
     // Global prefetch [1, PrefetchStages]
     static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
         GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);

@@ -341,6 +356,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
     if constexpr(TailNum == TailNumber::One)
     {
+        // printf("TailNumOne\n");
         block_sync_lds();
         // block_gemm.LocalPrefetch();
         block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
...
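The dump added above walks the thread buffer with static_for so each index is a compile-time constant, as tile element access requires. A rough standalone equivalent of the idiom (ck_tile's real static_for is a functor type and differs in detail):

#include <cstdio>
#include <utility>

// Invokes f with std::integral_constant indices, so each iteration's index
// can be used where a constant expression is required.
template <int Begin, int End, int Inc = 1, typename F>
void static_for(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        static_for<Begin + Inc, End, Inc>(std::forward<F>(f));
    }
}

int main()
{
    static_for<0, 4>([](auto i) { printf("%d ", i.value); }); // prints: 0 1 2 3
}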
@@ -43,6 +43,18 @@ struct WarpGemmImpl
     const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
     auto c_vec       = c.get_thread_buffer().template get_as<CVec>()[I0];
 
+    // if(threadIdx.x == 0) {
+    //     for(int i = 0; i < AWarpTensor::get_thread_buffer_size(); ++i) {
+    //         printf("A[%d]: %d\n", i, static_cast<int32_t>(a_vec[i]));
+    //     }
+    //     for(int i = 0; i < BWarpTensor::get_thread_buffer_size(); ++i) {
+    //         printf("B[%d]: %d\n", i, static_cast<int32_t>(b_vec[i]));
+    //     }
+    //     for(int i = 0; i < CWarpTensor::get_thread_buffer_size(); ++i) {
+    //         printf("C[%d]: %d\n", i, static_cast<int32_t>(c_vec[i]));
+    //     }
+    // }
+
     // c_vec += a_vec * b_vec
     WarpGemmAttribute{}(c_vec, a_vec, b_vec);
...
@@ -18,6 +18,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
     std::tuple< Row, Col, Row, F16, F16, F32, F16>
+    //std::tuple< Row, Row, Row, F16, F16, F32, F16>
     // TODO: fixme!
     // std::tuple< Col, Row, Row, F16, F16, F32, F16>,
     // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
...
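KernelTypes feeds GoogleTest's typed-test machinery: every TYPED_TEST in the suite below is instantiated once per tuple, with the fixture unpacking layouts and data types from it. A sketch of how such a list is consumed (the fixture body is illustrative; the real TestCkTileGemmMemPipeline unpacks more than shown):

#include <gtest/gtest.h>
#include <tuple>

template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test
{
protected:
    using ALayout = std::tuple_element_t<0, Tuple>;
    using BLayout = std::tuple_element_t<1, Tuple>;
    using CLayout = std::tuple_element_t<2, Tuple>;
    // elements 3..6: ADataType, BDataType, AccDataType, CDataType
};

// Instantiates every TYPED_TEST once per tuple in the KernelTypes list above.
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);

TYPED_TEST(TestCkTileGemmMemPipeline, LayoutsResolve)
{
    using ALayout = typename TestFixture::ALayout; // Row for the one active tuple
    (void)ALayout{};
}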
@@ -2,42 +2,64 @@
 TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
 {
-    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
-    constexpr int N = 1024;
-    constexpr int K = 320;
+    std::vector<int> Ms{128};
+    std::vector<int> Ns{128}; // M K K N M N
+    std::vector<int> Ks{33};
     for(int M : Ms)
-        this->Run(M, N, K);
+        for(int N : Ns)
+            for(int K : Ks)
+                this->Run(M, N, K);
 }
 
-TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
-{
-    std::vector<int> Ms{127, 255, 312, 799, 1573};
-    constexpr int N = 1024;
-    constexpr int K = 320;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
+// {
+//     std::vector<int> Ms{127, 255, 312, 799, 1573};
+//     constexpr int N = 1024;
+//     constexpr int K = 321;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
 
 // TODO: Seems like padding is not working!
 // Works only when K is a multiple of KPerBlock
-TYPED_TEST(TestCkTileGemmMemPipeline, DISABLED_PaddK)
-{
-    std::vector<int> Ms{127};
-    constexpr int N = 1024;
-    constexpr int K = 432;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 320;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
 
-TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
-{
-    std::vector<int> Ms{512};
-    constexpr int N = 1024;
-    constexpr int K = 512;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 322;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
+
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv2)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 346;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
+
+// TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
+// {
+//     std::vector<int> Ms{512};
+//     constexpr int N = 1024;
+//     constexpr int K = 512;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
@@ -78,6 +78,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
 
+    std::cout << "has hot loop " << has_hot_loop << std::endl;
+    std::cout << "num loop " << num_loop << std::endl;
+    std::cout << "tail_num " << static_cast<int32_t>(tail_num) - 1 << std::endl;
+
     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;
@@ -105,7 +109,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
     constexpr dim3 blocks = Kernel::BlockSize();
 
-    if(s.log_level_ > 0)
+    if(true)
     {
         std::cout << "Launching kernel with args:"
                   << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
@@ -119,14 +123,17 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     if(has_hot_loop)
     {
+        std::cout << "has hot loop xx\n";
         // Tail pipeline One to Seven
         if(tail_num == ck_tile::TailNumber::One)
         {
+            std::cout << "tail num one\n";
             Run(ck_tile::bool_constant<true>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
         }
         else if(tail_num == ck_tile::TailNumber::Full)
         {
+            std::cout << "tail num full\n";
             Run(ck_tile::bool_constant<true>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
         }
@@ -191,6 +198,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     // Tail number always 1
     if(tail_num == ck_tile::TailNumber::One)
     {
+        std::cout << "nohotloop tail num one xx\n";
         Run(ck_tile::bool_constant<false>{},
             ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
     }
@@ -267,8 +275,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     ck_tile::HostTensor<CDataType> c_m_n_dev_result(
         f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
 
-    ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
-    ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
+    ck_tile::FillMonotonicSeq<ADataType>{0, 0.01}(a_m_k);
+    ck_tile::FillMonotonicSeq<BDataType>{0, 0.01}(b_k_n);
+    //ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
+    //ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
 
     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
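Replacing the random integer fill with FillMonotonicSeq makes A and B deterministic ramps, so a wrong output element pinpoints where the pipeline diverges. Judging by the call sites, the functor takes an initial value and a step; a host-side sketch under that assumption (ck_tile's actual FillMonotonicSeq may differ):

#include <vector>

// Monotonic fill matching the {init, step} usage above; illustrative only.
template <typename T>
struct FillMonotonicSeq
{
    T init_;
    T step_;

    template <typename Range>
    void operator()(Range& r) const
    {
        T v = init_;
        for(auto& x : r)
        {
            x = v;
            v += step_;
        }
    }
};

int main()
{
    std::vector<float> a(8);
    FillMonotonicSeq<float>{0.0f, 0.01f}(a); // 0.00, 0.01, ..., 0.07
}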
@@ -291,6 +301,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     args.stride_B = stride_B;
     args.stride_C = stride_C;
 
+    std::cout << "kbatch " << kbatch << std::endl;
+    std::cout << "stride A " << stride_A << std::endl;
+    std::cout << "stride B " << stride_B << std::endl;
+    std::cout << "stride C " << stride_C << std::endl;
+
     invoke_gemm(args, ck_tile::stream_config{nullptr, false});
 
     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
...