Commit e511bb78 authored by coderfeli's avatar coderfeli

lds a,b ok

parent d51f4e52
@@ -503,10 +503,6 @@ include_directories(BEFORE
 )
 SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
-if(BUILD_DEV)
-    add_compile_options(-Werror)
-    add_compile_options(-Weverything)
-endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
...
@@ -66,7 +66,7 @@ else()
     -Wunreachable-code
     -Wunused
     -Wno-reserved-identifier
-    -Werror
+    # -Werror
     -Wno-option-ignored
     -Wsign-compare
     -Wno-extra-semi-stmt
...
@@ -117,6 +117,10 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
     ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+    // ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
+    // ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+    // ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
+    // ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
...
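Note: the commented-out fills are the debugging handles for this change. Deterministic inputs make GEMM outputs predictable, and with FillConstant{1.f} on both operands every element of C must equal the reduction length K (the disabled check later in this commit compares c_block_tile against 32, i.e. ones-filled inputs with K = 32). A minimal host-side sketch of that invariant (plain C++, independent of ck_tile; the 3x4x2 sizes are made up):

#include <cstdio>
#include <vector>

int main()
{
    const int M = 3, N = 4, K = 2;    // made-up sizes
    std::vector<float> a(M * K, 1.f); // FillConstant{1.f} analogue for A
    std::vector<float> b(K * N, 1.f); // and for B
    std::vector<float> c(M * N, 0.f);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[k * N + n];

    // With ones-filled inputs every output must equal K; anything else
    // flags an indexing/layout bug rather than a numerics issue.
    for(int i = 0; i < M * N; ++i)
        if(c[i] != static_cast<float>(K))
            std::printf("mismatch at %d: %f\n", i, c[i]);
    std::printf("expected value per element: %d\n", K);
    return 0;
}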
@@ -211,15 +211,16 @@ struct FillNormalDistributionIntegerValue
 template <typename T>
 struct FillMonotonicSeq
 {
-    T init_value_{0};
+    T init_value_{-1024};
     T step_{1};

     template <typename ForwardIter>
     void operator()(ForwardIter first, ForwardIter last) const
     {
-        std::generate(first, last, [=, n = init_value_]() mutable {
+        T step_start = init_value_;
+        std::generate(first, last, [&, n = init_value_]() mutable {
             auto tmp = n;
             n += step_;
+            if(n > step_start + 2047) { step_start += step_; n = step_start; }
             return tmp;
         });
     }
...
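Note: the new defaults start the sequence at -1024 and reset it once it climbs more than 2047 above the current base, so the generated values stay inside a 2048-wide integer window; presumably this keeps monotonic test fills exactly representable in half-precision inputs. A standalone sketch of the updated generator (plain C++, int standing in for the templated data type):

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    const int init_value = -1024;
    const int step       = 1;

    std::vector<int> buf(5000);
    int step_start = init_value;
    std::generate(buf.begin(), buf.end(), [&, n = init_value]() mutable {
        auto tmp = n;
        n += step;
        // Once the value drifts more than 2047 above the base, restart
        // from a base shifted by one step, bounding the spread of values.
        if(n > step_start + 2047)
        {
            step_start += step;
            n = step_start;
        }
        return tmp;
    });

    // First window: -1024 .. 1023 (2048 values), then back to -1023.
    std::printf("%d %d %d\n", buf[0], buf[2047], buf[2048]); // -1024 1023 -1023
    return 0;
}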
@@ -42,9 +42,6 @@ struct BlockGemmASmemBSmemCRegV1
                       KPerBlock == BlockGemmShape::kK,
                   "wrong!");
-        // if(threadIdx.x == 0 && blockIdx.x==0) {
-        //     printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
-        // }
         constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -56,12 +53,12 @@ struct BlockGemmASmemBSmemCRegV1
         constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
         constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
-        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
-        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
-        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
-        const index_t iMWarp = get_warp_id() / NWarp;
-        const index_t iNWarp = get_warp_id() % NWarp;
+        // constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
+        // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
+        // constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
+        // const index_t iMWarp = get_warp_id() / NWarp;
+        // const index_t iNWarp = get_warp_id() % NWarp;
         // if(threadIdx.x == 0 && blockIdx.x==0) {
         //     printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
@@ -69,91 +66,69 @@ struct BlockGemmASmemBSmemCRegV1
         // MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
-        // construct A-warp-window
         auto a_warp_window_tmp = make_tile_window(
             a_block_window.get_bottom_tensor_view(),
-            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
-            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
-            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+            make_tuple(MPerBlock, KPerBlock),
+            {0, 0},
+            Policy::template MakeALDSTileDistribution<Problem>());
-#if 0 // FIXME: using array will cause register spill
-        array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
-            {a_warp_window_tmp}};
-        for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
-        {
-            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
-            {
-                move_tile_window(a_warp_windows(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            }
-        }
-#else
-        statically_indexed_array<
-            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
-            MIterPerWarp>
-            a_warp_windows;
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
-                move_tile_window(a_warp_windows(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
-        });
-#endif
-        // construct B-warp-window
         auto b_warp_window_tmp = make_tile_window(
             b_block_window.get_bottom_tensor_view(),
-            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
-            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
-            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
+            make_tuple(NPerBlock, KPerBlock),
+            {0, 0},
+            Policy::template MakeBLDSTileDistribution<Problem>());
-#if 0 // FIXME: using array will cause register spill
-        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
-            {b_warp_window_tmp}};
-        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
-        {
-            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
-            {
-                move_tile_window(b_warp_windows(nIter)(kIter),
-                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
-            }
-        }
-#else
-        statically_indexed_array<
-            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
-            NIterPerWarp>
-            b_warp_windows;
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
-                move_tile_window(b_warp_windows(nIter)(kIter),
-                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
-        });
-#endif
+        auto a_block_tensor = load_tile(a_warp_window_tmp);
+        auto b_block_tensor = load_tile(b_warp_window_tmp);
+        // if (threadIdx.x == 0) {
+        //     printf("0\n");
+        //     constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
+        //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+        //         sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+        //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
+        //             printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
+        //         });
+        //         printf("\n");
+        //     });
+        // }
+        // __syncthreads();
+        using AWarpDstr = typename WG::AWarpDstr;
+        using BWarpDstr = typename WG::BWarpDstr;
         using CWarpDstr = typename WG::CWarpDstr;
+        using AWarpTensor = typename WG::AWarpTensor;
+        using BWarpTensor = typename WG::BWarpTensor;
         using CWarpTensor = typename WG::CWarpTensor;
+        constexpr auto a_warp_y_lengths =
+            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        constexpr auto b_warp_y_lengths =
+            to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
         constexpr auto c_warp_y_lengths =
             to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
+        constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
         // hot loop:
         static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // read A warp tensor from A block window
-                const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
+                // read A warp tensor from A Block window
+                AWarpTensor a_warp_tensor;
+                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    // read B warp tensor from B Block window
-                    const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
+                    // read B warp tensor from B block tensor
+                    BWarpTensor b_warp_tensor;
+                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
+                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
                     // read C warp tensor from C block tensor
                     CWarpTensor c_warp_tensor;
@@ -173,6 +148,36 @@ struct BlockGemmASmemBSmemCRegV1
                 });
             });
         });
+        // constexpr auto c_warp_y_lengths =
+        //     to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
+        // // hot loop:
+        // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+        //     static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+        //         // read A warp tensor from A block window
+        //         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+        //             // read B warp tensor from B Block window
+        //             // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
+        //             // read C warp tensor from C block tensor
+        //             CWarpTensor c_warp_tensor;
+        //             c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
+        //                 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+        //                 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+        //             // warp GEMM
+        //             WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
+        //             // write C warp tensor into C block tensor
+        //             c_block_tensor.set_y_sliced_thread_data(
+        //                 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+        //                 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+        //                 c_warp_tensor.get_thread_buffer());
+        //         });
+        //     });
+        // });
     }

     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
@@ -217,5 +222,72 @@ struct BlockGemmASmemBSmemCRegV1
         return c_block_tensor;
     }
 };
+// construct A-warp-window
+// auto a_warp_window_tmp = make_tile_window(
+//     a_block_window.get_bottom_tensor_view(),
+//     make_tuple(number<WG::kM>{}, number<WG::kK>{}),
+//     a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
+//     make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+// #if 0 // FIXME: using array will cause register spill
+//     array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
+//         {a_warp_window_tmp}};
+//     for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
+//     {
+//         for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
+//         {
+//             move_tile_window(a_warp_windows(mIter)(kIter),
+//                              {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+//         }
+//     }
+// #else
+//     statically_indexed_array<
+//         statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
+//         MIterPerWarp>
+//         a_warp_windows;
+//     static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+//         static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+//             a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
+//             move_tile_window(a_warp_windows(mIter)(kIter),
+//                              {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+//         });
+//     });
+// #endif
+// construct B-warp-window
+// auto b_warp_window_tmp = make_tile_window(
+//     b_block_window.get_bottom_tensor_view(),
+//     make_tuple(number<WG::kN>{}, number<WG::kK>{}),
+//     b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
+//     make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
+// #if 0 // FIXME: using array will cause register spill
+//     array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
+//         {b_warp_window_tmp}};
+//     for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
+//     {
+//         for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
+//         {
+//             move_tile_window(b_warp_windows(nIter)(kIter),
+//                              {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
+//         }
+//     }
+// #else
+//     statically_indexed_array<
+//         statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
+//         NIterPerWarp>
+//         b_warp_windows;
+//     static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+//         static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+//             b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
+//             move_tile_window(b_warp_windows(nIter)(kIter),
+//                              {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
+//         });
+//     });
+// #endif
 } // namespace ck_tile
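Note: this is the heart of the commit. A and B are now loaded from LDS once per block as whole distributed tiles (via the new MakeALDSTileDistribution/MakeBLDSTileDistribution), and each warp-GEMM iteration slices its fragment out of the in-register thread buffer with get_y_sliced_thread_data instead of issuing a separate load_tile per (mIter, kIter) window. A rough host-side model of that slicing (made-up sizes; a flat std::array stands in for the distributed thread buffer):

#include <array>
#include <cstdio>

constexpr int MIter = 4, KIter = 4, YLen = 8; // made-up iteration counts

using ThreadBuffer = std::array<float, MIter * KIter * YLen>;

// Analogue of get_y_sliced_thread_data with slice lengths {1, 1, YLen}:
// the (mIter, kIter) fragment is a contiguous YLen-sized sub-run.
std::array<float, YLen> slice(const ThreadBuffer& buf, int mIter, int kIter)
{
    std::array<float, YLen> frag{};
    const int base = (mIter * KIter + kIter) * YLen;
    for(int y = 0; y < YLen; ++y)
        frag[y] = buf[base + y];
    return frag;
}

int main()
{
    ThreadBuffer a_block{};
    for(int i = 0; i < MIter * KIter * YLen; ++i)
        a_block[i] = static_cast<float>(i); // pretend block-wide load result

    auto frag = slice(a_block, 2, 1); // fragment fed to one warp GEMM
    std::printf("first element of fragment (2,1): %g\n", frag[0]); // 72
    return 0;
}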
@@ -40,7 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
             return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
         }
 #else
-        return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2);
+        return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
         // return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
 #endif
     }
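Note: the tuple reads as (warp GEMM, MWarp, NWarp): a 32x32x16 mfma warp tile on a 2x2 warp grid, with the C fragment in transposed distribution. Taking the 256x256x32 block shape implied by the commented debug printout above (MPerBlockPerIter 64 x MIterPerWarp 4, etc.; an assumption), the iteration counts work out as below; KIterPerWarp halves relative to the old K8 warp GEMM:

#include <cstdio>

int main()
{
    constexpr int MPerBlock = 256, NPerBlock = 256, KPerBlock = 32; // assumed
    constexpr int MWarp = 2, NWarp = 2;            // make_tuple(..., 2, 2)
    constexpr int WG_M = 32, WG_N = 32, WG_K = 16; // M32N32K16

    constexpr int MIterPerWarp = MPerBlock / (MWarp * WG_M); // 4
    constexpr int NIterPerWarp = NPerBlock / (NWarp * WG_N); // 4
    constexpr int KIterPerWarp = KPerBlock / WG_K;           // 2 (was 4 with K8)

    static_assert(MPerBlock % (MWarp * WG_M) == 0, "block tile must divide evenly");
    std::printf("MIter %d NIter %d KIter %d\n", MIterPerWarp, NIterPerWarp, KIterPerWarp);
    return 0;
}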
@@ -55,6 +55,96 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
             static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
         }
     }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution()
+    {
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
+        {
+            static_assert(false, "Unsupported tensor_layout right now.");
+        }
+        else
+        {
+            // K split as (Number<krepeat>, Number<klane>, Number<Kpack>)
+            constexpr index_t K2 = 16 / sizeof(ADataType);
+            constexpr index_t K1 = 2;
+            constexpr index_t K0 = KPerBlock / K1 / K2;
+            // M split as (Number<mrepeat>, Number<mwaves>, Number<MPerXdl>)
+            constexpr index_t M2 = 32; // MPerXdl
+            constexpr index_t M1 = 2;  // MWave
+            // coalesced reads within each block
+            if constexpr(get_warp_size() % (M2 * K0) == 0)
+            {
+                static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
+                static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
+                constexpr index_t M0 = MPerBlock / (M2 * M1);
+                return make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<2>,
+                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+                                               tuple<sequence<1, 0>, sequence<2, 1>>,
+                                               tuple<sequence<1, 0>, sequence<1, 2>>,
+                                               sequence<1, 2, 2>,
+                                               sequence<0, 0, 2>>{});
+            }
+            else
+            {
+                static_assert(false, "Unsupported shape right now.");
+            }
+        }
+    }
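Note: a constexpr sanity check of the factors above for fp16 A and the assumed 256x32 (M x K) block tile; K2 is one 16-byte vector of elements, and the M2 x K0 footprint exactly fills a 64-lane wavefront, which is what the get_warp_size() % (M2 * K0) == 0 guard demands:

#include <cstdio>

int main()
{
    constexpr int sizeof_fp16 = 2;
    constexpr int MPerBlock = 256, KPerBlock = 32; // assumed block tile

    constexpr int K2 = 16 / sizeof_fp16;      // 8 elements per 16-byte load
    constexpr int K1 = 2;
    constexpr int K0 = KPerBlock / K1 / K2;   // 2
    constexpr int M2 = 32;                    // MPerXdl
    constexpr int M1 = 2;                     // MWave
    constexpr int M0 = MPerBlock / (M2 * M1); // 4

    static_assert(64 % (M2 * K0) == 0, "lanes must tile the M2 x K0 footprint");
    std::printf("M %d/%d/%d  K %d/%d/%d\n", M0, M1, M2, K0, K1, K2);
    return 0;
}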
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
+    {
+        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            static_assert(false, "Unsupported tensor_layout right now.");
+        }
+        else
+        {
+            // K split as (Number<krepeat>, Number<klane>, Number<Kpack>)
+            constexpr index_t K2 = 16 / sizeof(BDataType);
+            constexpr index_t K1 = 2;
+            constexpr index_t K0 = KPerBlock / K1 / K2;
+            // N split as (Number<nrepeat>, Number<nwaves>, Number<NPerXdl>)
+            constexpr index_t N2 = 32; // NPerXdl
+            constexpr index_t N1 = 2;  // NWave
+            // coalesced reads within each block
+            if constexpr(get_warp_size() % (N2 * K0) == 0)
+            {
+                static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
+                static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
+                constexpr index_t N0 = NPerBlock / (N2 * N1);
+                return make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<2>,
+                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
+                                               tuple<sequence<0, 1>, sequence<2, 1>>,
+                                               tuple<sequence<0, 1>, sequence<1, 2>>,
+                                               sequence<1, 2, 2>,
+                                               sequence<0, 0, 2>>{});
+            }
+            else
+            {
+                static_assert(false, "Unsupported shape right now.");
+            }
+        }
+    }
 };
 } // namespace ck_tile
@@ -133,7 +133,16 @@ struct GemmPipelineAGmemBGmemCRegV1
         // global read 0
         auto a_block_tile = load_tile(a_copy_dram_window);
         auto b_block_tile = load_tile(b_copy_dram_window);
+        // if (threadIdx.x == 0) {
+        //     constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
+        //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+        //         sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+        //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
+        //             printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
+        //         });
+        //         printf("\n");
+        //     });
+        // }
         {
             // move to 1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});
@@ -170,7 +179,17 @@ struct GemmPipelineAGmemBGmemCRegV1
                 store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
             }
         }
+        // __syncthreads();
+        // if (threadIdx.x == 0) {
+        //     for (int j = 0; j < 256; j++) {
+        //         for(int i = 0; i < 32; i++) {
+        //             int ik0 = i / 8;
+        //             int ik1 = i % 8;
+        //             printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
+        //         }
+        //         printf("\n");
+        //     }
+        // }
         index_t iCounter = num_loop - 1;
         while(iCounter > 0)
         {
@@ -219,6 +238,17 @@ struct GemmPipelineAGmemBGmemCRegV1
             block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
         }
+        // if (threadIdx.x == 0) {
+        //     constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
+        //     sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+        //         sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+        //             constexpr auto i_j_idx = make_tuple(idx0, idx1);
+        //             if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
+        //                 printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
+        //         });
+        //         printf("\n");
+        //     });
+        // }
         return c_block_tile;
     }
...
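Note: for orientation, the hunks above sit inside the usual software-pipelined GEMM main loop: tile 0 is read from global memory and staged to LDS before the loop, and each iteration prefetches tile i+1 while block_gemm consumes tile i. A schematic of that structure (plain C++ placeholders, not ck_tile code):

#include <cstdio>

void load_global(int i)   { std::printf("global read %d\n", i); }
void store_lds(int i)     { std::printf("lds write %d\n", i); }
void gemm_from_lds(int i) { std::printf("gemm on tile %d\n", i); }

int main()
{
    const int num_loop = 4; // made-up trip count
    load_global(0);         // "global read 0"
    store_lds(0);

    int i = 0;
    int iCounter = num_loop - 1;
    while(iCounter > 0)
    {
        load_global(i + 1); // prefetch next tile from global memory
        gemm_from_lds(i);   // compute on the tile already staged in LDS
        store_lds(i + 1);
        ++i;
        --iCounter;
    }
    gemm_from_lds(i);       // tail iteration
    return 0;
}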
@@ -54,7 +54,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
-            make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
+            make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}),
             number<8>{},
             number<1>{});
@@ -77,7 +77,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
-            make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
+            make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}),
             number<8>{},
             number<1>{});
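Note: dropping the +1 removes the padding row from the (K0, N, K1) LDS layout, making it dense: an element's offset becomes ik1 + n*8 + ik0*8*kNPerBlock, exactly the formula in the commented LDS dump above. Padding is the classic bank-conflict workaround, so this only stays safe because the new LDS tile distributions change the access pattern. A small offset check under assumed sizes (kNPerBlock = 256, K1 = 8):

#include <cstdio>

// Offset in the (K0, N, K1) descriptor given the K0 stride.
int offset(int ik0, int n, int ik1, int k0_stride)
{
    return ik0 * k0_stride + n * 8 + ik1; // strides (k0_stride, 8, 1)
}

int main()
{
    constexpr int kNPerBlock = 256; // assumed
    const int padded   = (kNPerBlock + 1) * 8; // old: extra 8-wide row per K0 slice
    const int unpadded = kNPerBlock * 8;       // new: dense layout

    std::printf("element (1, 3, 5): padded %d, unpadded %d\n",
                offset(1, 3, 5, padded), offset(1, 3, 5, unpadded));
    // The dense stride matches the dump index ik1 + j * 8 + ik0 * 8 * 256
    // and saves 8 elements of LDS per K0 slice, at the cost of a different
    // bank mapping.
    return 0;
}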
@@ -130,74 +130,74 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }
 #elif 1
     // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-        constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
-            number<kKPerBlock>{});
-        constexpr index_t kK1 = 16 / sizeof(ADataType);
-        constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            a_lds_block_desc_d1_d2_d3,
-            make_tuple(
-                make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-        constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
-            a_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        return a_lds_block_desc_m_k;
-    }
-    // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-        constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
-            number<kKPerBlock>{});
-        constexpr index_t kK1 = 16 / sizeof(BDataType);
-        constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            b_lds_block_desc_d1_d2_d3,
-            make_tuple(
-                make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-        constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
-            b_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        return b_lds_block_desc_n_k;
-    }
+    // template <typename Problem>
+    // CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
+    // {
+    //     using namespace ck_tile;
+    //     using ADataType = remove_cvref_t<typename Problem::ADataType>;
+    //     constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
+    //     constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+    //     constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
+    //         make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
+    //         number<kKPerBlock>{});
+    //     constexpr index_t kK1 = 16 / sizeof(ADataType);
+    //     constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
+    //         a_lds_block_desc_d1_d2_d3,
+    //         make_tuple(
+    //             make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
+    //             make_pass_through_transform(2)),
+    //         make_tuple(sequence<0, 2>{}, sequence<1>{}),
+    //         make_tuple(sequence<0, 2>{}, sequence<1>{}));
+    //     constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
+    //         a_lds_block_desc_d4_d5_d6,
+    //         make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
+    //                    make_pass_through_transform(kKPerBlock)),
+    //         make_tuple(sequence<0, 1>{}, sequence<2>{}),
+    //         make_tuple(sequence<0>{}, sequence<1>{}));
+    //     return a_lds_block_desc_m_k;
+    // }
+    // // fake XOR
+    // template <typename Problem>
+    // CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
+    // {
+    //     using namespace ck_tile;
+    //     using BDataType = remove_cvref_t<typename Problem::BDataType>;
+    //     constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
+    //     constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+    //     constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
+    //         make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
+    //         number<kKPerBlock>{});
+    //     constexpr index_t kK1 = 16 / sizeof(BDataType);
+    //     constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
+    //         b_lds_block_desc_d1_d2_d3,
+    //         make_tuple(
+    //             make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
+    //             make_pass_through_transform(2)),
+    //         make_tuple(sequence<0, 2>{}, sequence<1>{}),
+    //         make_tuple(sequence<0, 2>{}, sequence<1>{}));
+    //     constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
+    //         b_lds_block_desc_d4_d5_d6,
+    //         make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
+    //                    make_pass_through_transform(kKPerBlock)),
+    //         make_tuple(sequence<0, 1>{}, sequence<2>{}),
+    //         make_tuple(sequence<0>{}, sequence<1>{}));
+    //     return b_lds_block_desc_n_k;
+    // }
 #endif
     template <typename Problem>
...
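Note: the now-disabled "fake XOR" descriptors implement the standard XOR swizzle: the k-vector index of each row is XORed with a row-dependent value, so a column walk hits different LDS banks on every row and conflicts are avoided without padding. A generic illustration of the pattern (plain C++; this mirrors the common scheme, not necessarily the exact make_xor_transform mapping):

#include <cstdio>

constexpr int kRows = 8, kCols = 8; // tile of 8-element vectors, sizes made up

// XOR-swizzled offset at vector granularity.
int swizzled_offset(int row, int vec_col)
{
    const int sw_col = vec_col ^ (row % kCols); // row-dependent XOR
    return row * kCols + sw_col;
}

int main()
{
    // Walking column 0 lands in a different vector slot on every row:
    for(int r = 0; r < kRows; ++r)
        std::printf("row %d -> slot %d\n", r, swizzled_offset(r, 0) % kCols);
    return 0;
}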