Commit 1f91449d authored by Jakub Piasecki
Browse files

first fixes for gemm

parent 824809c1
......@@ -57,8 +57,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
printf("PrefetchStages: %d\n", BaseGemmPipeline::PrefetchStages);
printf("num_loop: %d\n", num_loop);
printf("has_hot_loop: %d\n", has_hot_loop);
printf("tail_num: %d\n", static_cast<int>(tail_num));
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
......@@ -86,7 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
if(true)
{
std::cout << "Lunching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
......@@ -169,7 +174,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
} // what if not?
}
return ave_time;
......
......@@ -67,6 +67,7 @@ int run_gemm_example(int argc, char* argv[])
int n_repeat = arg_parser.get_int("repeat");
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
//using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
......
......@@ -166,7 +166,7 @@
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#define CK_TILE_DEBUG_LOG 1
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
......
......@@ -138,6 +138,11 @@ struct BlockGemmASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
// if(threadIdx.x == 0) {
// printf("block gemm\n");
// }
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
......@@ -162,6 +167,12 @@ struct BlockGemmASmemBSmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// if(threadIdx.x == 0) {
// printf("C warp\n");
// tile_elementwise_inout([](auto& c) { printf("%f ", static_cast<float>(c));}, c_block_tensor);
// printf("\n");
// }
});
});
});
......
......@@ -276,10 +276,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if(threadIdx.x == 0) {
printf("gemm_pipeline_ag_bg_cr_mem\n");
printf("A in: ");
static_for<0, 16, 1>{}([&](auto i) {
printf("%f ", static_cast<float>(a_block_tiles.get(I0{}).get_thread_buffer()[i]));
});
printf("\nB in: ");
static_for<0, 16, 1>{}([&](auto i) {
printf("%f ", static_cast<float>(b_block_tiles.get(I0{}).get_thread_buffer()[i]));
});
printf("\n");
}
// LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// print a_block_tiles, b_block_tiles
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
......@@ -341,6 +356,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
if constexpr(TailNum == TailNumber::One)
{
//printf("TailNumOne\n");
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
......
......@@ -43,6 +43,18 @@ struct WarpGemmImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// if(threadIdx.x == 0) {
// for(int i=0; i<AWarpTensor::get_thread_buffer_size(); ++i) {
// printf("A[%d]: %d\n", i, static_cast<int32_t>(a_vec[i]));
// }
// for(int i=0; i<BWarpTensor::get_thread_buffer_size(); ++i) {
// printf("B[%d]: %d\n", i, static_cast<int32_t>(b_vec[i]));
// }
// for(int i=0; i<CWarpTensor::get_thread_buffer_size(); ++i) {
// printf("C[%d]: %d\n", i, static_cast<int32_t>(c_vec[i]));
// }
// }
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
......
......@@ -18,6 +18,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Col, Row, F16, F16, F32, F16>
//std::tuple< Row, Row, Row, F16, F16, F32, F16>
// TODO: fixme!
// std::tuple< Col, Row, Row, F16, F16, F32, F16>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
......
......@@ -2,42 +2,64 @@
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
constexpr int K = 320;
std::vector<int> Ms{128};
std::vector<int> Ns{128}; // M K K N M N
std::vector<int> Ks{33};
for(int M : Ms)
this->Run(M, N, K);
for(int N : Ns)
for(int K : Ks)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
// TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
// {
// std::vector<int> Ms{127, 255, 312, 799, 1573};
// constexpr int N = 1024;
// constexpr int K = 321;
for(int M : Ms)
this->Run(M, N, K);
}
// for(int M : Ms)
// this->Run(M, N, K);
// }
// TODO: Seems like padding is not working!
// Works only when K is a multiple of KPerBlock
TYPED_TEST(TestCkTileGemmMemPipeline, DISABLED_PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 1024;
constexpr int K = 432;
// TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
// {
// std::vector<int> Ms{1};
// constexpr int N = 128;
// constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
// for(int M : Ms)
// this->Run(M, N, K);
// }
TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 512;
// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv)
// {
// std::vector<int> Ms{1};
// constexpr int N = 128;
// constexpr int K = 322;
for(int M : Ms)
this->Run(M, N, K);
}
// for(int M : Ms)
// this->Run(M, N, K);
// }
// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv2)
// {
// std::vector<int> Ms{1};
// constexpr int N = 128;
// constexpr int K = 346;
// for(int M : Ms)
// this->Run(M, N, K);
// }
// TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
// {
// std::vector<int> Ms{512};
// constexpr int N = 1024;
// constexpr int K = 512;
// for(int M : Ms)
// this->Run(M, N, K);
// }
......@@ -78,6 +78,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
std::cout << "has hot loop " << has_hot_loop << std::endl;
std::cout << "num loop " << num_loop << std::endl;
std::cout << "tail_num " << static_cast<int32_t>(tail_num) - 1 << std::endl;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
......@@ -105,7 +109,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
if(true)
{
std::cout << "Lunching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
......@@ -119,14 +123,17 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if(has_hot_loop)
{
std::cout << "has hot loop xx\n";
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
std::cout << "tail num one\n";
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
std::cout << "tail num full\n";
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
......@@ -191,6 +198,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
// Tail number always 1
if(tail_num == ck_tile::TailNumber::One)
{
std::cout << "nohotloop tail num one xx\n";
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
......@@ -267,8 +275,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
ck_tile::FillMonotonicSeq<ADataType>{0, 0.01}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{0, 0.01}(b_k_n);
//ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
//ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
......@@ -291,6 +301,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
args.stride_B = stride_B;
args.stride_C = stride_C;
std::cout << "kbatch " << kbatch << std::endl;
std::cout << "stride A " << stride_A << std::endl;
std::cout << "stride B " << stride_B << std::endl;
std::cout << "stride C " << stride_C << std::endl;
invoke_gemm(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment