Commit e511bb78 authored by coderfeli

lds a,b ok

parent d51f4e52
......@@ -503,10 +503,6 @@ include_directories(BEFORE
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
......@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
# -Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
......@@ -117,6 +117,10 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
// ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
// ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
// ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
......
......@@ -211,15 +211,16 @@ struct FillNormalDistributionIntegerValue
template <typename T>
struct FillMonotonicSeq
{
T init_value_{0};
T init_value_{-1024};
T step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = init_value_]() mutable {
T step_start = init_value_;
std::generate(first, last, [&, n = init_value_]() mutable {
auto tmp = n;
n += step_;
if(n > step_start + 2047) { step_start += step_; n = step_start; }
return tmp;
});
}
......
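Editor's note on the FillMonotonicSeq change above: the new wrap keeps every generated value inside a 2048-wide band starting at init_value_, presumably so that integer test data stays exactly representable in fp16 (integers of magnitude up to 2048 are exact in half precision). A minimal standalone sketch of the patched generator; the buffer size and the printed indices are illustrative assumptions:

#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    const float init_value = -1024.f; // matches the new init_value_{-1024}
    const float step       = 1.f;
    std::vector<float> buf(4096);
    float step_start = init_value;
    std::generate(buf.begin(), buf.end(), [&, n = init_value]() mutable {
        auto tmp = n;
        n += step;
        // wrap so values stay in [step_start, step_start + 2047]
        if(n > step_start + 2047) { step_start += step; n = step_start; }
        return tmp;
    });
    std::cout << buf[0] << ' ' << buf[2047] << ' ' << buf[2048] << '\n'; // -1024 1023 -1023
}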
......@@ -42,9 +42,6 @@ struct BlockGemmASmemBSmemCRegV1
KPerBlock == BlockGemmShape::kK,
"wrong!");
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
// }
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
......@@ -56,12 +53,12 @@ struct BlockGemmASmemBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iNWarp = get_warp_id() % NWarp;
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
......@@ -69,91 +66,69 @@ struct BlockGemmASmemBSmemCRegV1
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
make_tuple(MPerBlock, KPerBlock),
{0, 0},
Policy::template MakeALDSTileDistribution<Problem>());
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
make_tuple(NPerBlock, KPerBlock),
{0, 0},
Policy::template MakeBLDSTileDistribution<Problem>());
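// Editor's note (an assumption, reconstructed from the added lines of this
// hunk): after the patch both warp windows presumably become whole-block LDS
// windows at origin {0, 0}, with the per-warp placement folded into the new
// Make{A,B}LDSTileDistribution policies instead of the iMWarp/iNWarp offsets:
//
//     auto a_warp_window_tmp = make_tile_window(
//         a_block_window.get_bottom_tensor_view(),
//         make_tuple(MPerBlock, KPerBlock),
//         {0, 0},
//         Policy::template MakeALDSTileDistribution<Problem>());
//     auto b_warp_window_tmp = make_tile_window(
//         b_block_window.get_bottom_tensor_view(),
//         make_tuple(NPerBlock, KPerBlock),
//         {0, 0},
//         Policy::template MakeBLDSTileDistribution<Problem>());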
auto a_block_tensor = load_tile(a_warp_window_tmp);
auto b_block_tensor = load_tile(b_warp_window_tmp);
// if (threadIdx.x == 0) {
// printf("0\n");
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
// });
// printf("\n");
// });
// }
// __syncthreads();
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
......@@ -173,6 +148,36 @@ struct BlockGemmASmemBSmemCRegV1
});
});
});
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// // hot loop:
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// // read A warp tensor from A block window
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// // read B warp tensor from B Block window
// // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// // read C warp tensor from C block tensor
// CWarpTensor c_warp_tensor;
// c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// // warp GEMM
// WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
// // write C warp tensor into C block tensor
// c_block_tensor.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensor.get_thread_buffer());
// });
// });
// });
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
......@@ -217,5 +222,72 @@ struct BlockGemmASmemBSmemCRegV1
return c_block_tensor;
}
};
// construct A-warp-window
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
// {a_warp_window_tmp}};
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
// construct B-warp-window
// auto b_warp_window_tmp = make_tile_window(
// b_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
// {b_warp_window_tmp}};
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
} // namespace ck_tile
......@@ -40,7 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
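Editor's note: the switch from WarpGemmMfmaF16F16F32M32N32K8 to the M32N32K16 transposed-C variant doubles the K depth consumed per warp-gemm while keeping the 2x2 warp grid. A quick check of the per-iteration C coverage; the warp counts come from the tuple above, the tile sizes from the warp-gemm name:

constexpr int MWarp = 2, NWarp = 2;          // from make_tuple(..., 2, 2)
constexpr int WgM = 32, WgN = 32, WgK = 16;  // WarpGemmMfmaF16F16F32M32N32K16...
constexpr int MPerIter = MWarp * WgM;        // 64 C rows per K-iteration
constexpr int NPerIter = NWarp * WgN;        // 64 C columns per K-iteration
static_assert(MPerIter == 64 && NPerIter == 64, "2x2 warps cover a 64x64 C tile");
static_assert(WgK == 2 * 8, "twice the K depth of the old M32N32K8 warp gemm");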
......@@ -55,6 +55,96 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
static_assert(false, "Unsupported tensor_layout right now.");
}
else
{
// Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr index_t K2 = 16 / sizeof(ADataType);
constexpr index_t K1 = 2;
constexpr index_t K0 = KPerBlock / K1 / K2;
// Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr index_t M2 = 32; // MPerXdl
constexpr index_t M1 = 2;  // MWave
// coalesced reads within each block
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
static_assert(M2 != 0, "M2 must be non-zero to avoid division by zero.");
static_assert(M1 != 0, "M1 must be non-zero to avoid division by zero.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<2>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(false, "Unsupported shape right now.");
}
}
}
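// Editor's sketch: plugging in the block shape recorded by the old debug
// printf earlier in this commit (MPerBlock 256, KPerBlock 32; treat these as
// assumptions for any other configuration) with fp16 data gives
//     K2 = 16 / sizeof(fp16) = 8,  K1 = 2,  K0 = 32 / (2 * 8) = 2
//     M2 = 32 (MPerXdl), M1 = 2 (MWave), M0 = 256 / (32 * 2) = 4
// so the guard get_warp_size() % (M2 * K0) == 64 % 64 == 0 holds on a
// 64-lane wavefront and the coalesced branch above is taken.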
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
static_assert(false, "Unsupported tensor_layout right now.");
}
else
{
// Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr index_t K2 = 16 / sizeof(BDataType);
constexpr index_t K1 = 2;
constexpr index_t K0 = KPerBlock / K1 / K2;
// Number<nrepeat>{}, Number<nwaves>{}, Number<NPerXdl>{}))),
constexpr index_t N2 = 32; // NPerXdl
constexpr index_t N1 = 2;  // NWave
// coalesced reads within each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
static_assert(N2 != 0, "N2 must be non-zero to avoid division by zero.");
static_assert(N1 != 0, "N1 must be non-zero to avoid division by zero.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<2>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(false, "Unsupported shape right now.");
}
}
}
};
} // namespace ck_tile
......@@ -133,7 +133,16 @@ struct GemmPipelineAGmemBGmemCRegV1
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
......@@ -170,7 +179,17 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
}
// __syncthreads();
// if (threadIdx.x == 0) {
// for (int j = 0; j < 256; j++) {
// for(int i = 0; i < 32; i++) {
// int ik0 = i /8;
// int ik1 = i % 8;
// printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
// }
// printf("\n");
// }
// }
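// Editor's note: the indexing in the commented dump above implies a B LDS
// layout of (K0, N, K1) with N = 256 and K1 = 8 for this configuration (an
// inference from p_b_lds[ik1 + j * 8 + ik0 * 8 * 256], not verified): element
// (n, k) lives at k % 8 + n * 8 + (k / 8) * 8 * 256, i.e. K is split into
// 8-element vectors and consecutive K-vectors of one row are 256 * 8
// elements apart.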
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
......@@ -219,6 +238,17 @@ struct GemmPipelineAGmemBGmemCRegV1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
return c_block_tile;
}
......
......@@ -54,7 +54,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
......@@ -77,7 +77,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
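Editor's note: both hunks above drop the "+ 1" row padding from the naive LDS descriptors. The padding staggered row starts across LDS banks to avoid conflicts; with it gone the layout is fully packed, and bank-conflict avoidance presumably falls to the XOR swizzle enabled below. A small offset comparison, sketched assuming fp16 and kMPerBlock = 256:

#include <cstdio>

int main()
{
    constexpr int kMPerBlock    = 256;
    constexpr int padded_stride = (kMPerBlock + 1) * 8; // old row stride, in elements
    constexpr int packed_stride = kMPerBlock * 8;       // new row stride, in elements
    // offset of element (k0, m, k1) in the (K/8, M, 8) descriptor
    auto offset = [](int stride0, int k0, int m, int k1) {
        return k0 * stride0 + m * 8 + k1;
    };
    std::printf("k0 = 1 starts at %d (padded) vs %d (packed)\n",
                offset(padded_stride, 1, 0, 0), offset(packed_stride, 1, 0, 0)); // 2056 vs 2048
}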
......@@ -130,74 +130,74 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
#elif 1
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k;
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
// {
// using namespace ck_tile;
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
// constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
// make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
// number<kKPerBlock>{});
// constexpr index_t kK1 = 16 / sizeof(ADataType);
// constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
// a_lds_block_desc_d1_d2_d3,
// make_tuple(
// make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
// make_pass_through_transform(2)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<0, 2>{}, sequence<1>{}));
// constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
// a_lds_block_desc_d4_d5_d6,
// make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
// make_pass_through_transform(kKPerBlock)),
// make_tuple(sequence<0, 1>{}, sequence<2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_m_k;
// }
// // fake XOR
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
// {
// using namespace ck_tile;
// using BDataType = remove_cvref_t<typename Problem::BDataType>;
// constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
// make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
// number<kKPerBlock>{});
// constexpr index_t kK1 = 16 / sizeof(BDataType);
// constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
// b_lds_block_desc_d1_d2_d3,
// make_tuple(
// make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
// make_pass_through_transform(2)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<0, 2>{}, sequence<1>{}));
// constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
// b_lds_block_desc_d4_d5_d6,
// make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
// make_pass_through_transform(kKPerBlock)),
// make_tuple(sequence<0, 1>{}, sequence<2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc_n_k;
// }
#endif
template <typename Problem>
......
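Editor's note on the "fake XOR" descriptors above: make_xor_transform replaces the removed padding by folding one coordinate into the other, so rows land in different LDS bank groups without wasting space. A generic illustration of the idea (not ck_tile's exact transform semantics, which this diff does not show):

#include <cstdio>

int main()
{
    constexpr int rows = 4, cols = 4, vec = 8; // toy sizes; kK1 would be 8 for fp16
    for(int m = 0; m < rows; ++m)
    {
        for(int k = 0; k < cols; ++k)
            std::printf("%3d ", ((k ^ m) % cols) * vec); // swizzled vector offset
        std::printf("\n");
    }
    // every printed column holds four distinct offsets, so the four rows of a
    // column never collide on the same bank group
}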