Commit da59d3b2 authored by coderfeli

Remove leftover debug comments and revert unused changes

parent 6fd51c43
@@ -15,7 +15,7 @@ using F16 = ck::half_t;
using F32 = float;
using ALayout = Row;
using BLayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
@@ -32,17 +32,15 @@ using DeviceGemmInstance =
PassThrough, PassThrough, PassThrough, GemmDefault,
2, 256,
256, 256,
32, 8, 8,
32, 8, 4,
32, 32,
4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 4, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
//./bin/example_gemm_xdl_fp16_v2 0 0 1 5120 5120 8320 8320 8320 5120
// clang-format on
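The benchmark command preserved in the removed comment just above follows the usual argument order of these CK examples (an assumption; the example's argument parsing is not shown in this diff):

// ./bin/example_gemm_xdl_fp16_v2 <verify> <init> <time_kernel> <M> <N> <K> <StrideA> <StrideB> <StrideC>
// i.e. verification off, init method 0, a timed run of a 5120 x 5120 x 8320 fp16 GEMM
// with strides 8320, 8320, 5120.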
using ReferenceGemmInstance = ck::tensor_operation::host::
......
@@ -22,12 +22,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
// This part comes from the CodeGen.
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
@@ -40,6 +43,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Whether to do the CShuffle (transpose before writing to global memory) depends on the
// output layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -47,21 +52,27 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
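The partitioner presumably assigns one workgroup per output tile of the shape defined above (times kbatch for split-K), so the grid size follows directly from the problem size; a back-of-envelope sketch with illustrative values, not from this commit:

// With M_Tile = N_Tile = 128 and, say, M = N = 3840:
//   grid = ceil(3840 / 128) * ceil(3840 / 128) = 30 * 30 = 900 workgroups,
// each looping over ceil(K / K_Tile) = ceil(K / 32) iterations along K.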
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
// CDataType,
// M_Warp * N_Warp * K_Warp * Warp_Size,
// 64,
// TilePartitioner::kN,
// kPadM,
// kPadN>>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
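The conditional above picks the epilogue type at compile time. A minimal self-contained sketch of the same std::conditional_t idiom, with illustrative stand-in types rather than the CK ones:

#include <type_traits>

struct ShuffleEpilogue { /* would transpose through LDS before the global store */ };
struct DirectEpilogue  { /* would store accumulators directly */ };

template <bool NeedsShuffle>
using Epilogue = std::conditional_t<NeedsShuffle, ShuffleEpilogue, DirectEpilogue>;

static_assert(std::is_same_v<Epilogue<true>, ShuffleEpilogue>, "shuffle path selected");
static_assert(std::is_same_v<Epilogue<false>, DirectEpilogue>, "direct path selected");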
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 2>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// TODO: Add the codegen part to test different pipeline policies in GEMM.
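The shape, partitioner, pipeline, and epilogue assembled above are then composed into the Kernel type whose GridSize/BlockSize are queried below; a hedged sketch of that composition (the exact wrapper name in this revision is an assumption):

// Assumed composition; the kernel wrapper may be spelled differently in this revision.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;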
@@ -81,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......
@@ -119,9 +119,12 @@ int run_gemm_example_with_layouts(int argc,
} else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else {
} else if (init_method == 2) {
ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
} else {
a_m_k.SetZero();
b_k_n.SetZero();
}
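The new fallback zero-fills both operands for any unrecognized init_method, while init_method 2 keeps the all-ones fill, which is convenient for spot checks: with all-ones inputs every output element equals K exactly. A minimal host-side sketch of that check, assuming the result has been copied back into a flat float buffer (names are illustrative):

#include <cstddef>

// Hypothetical verification helper for the all-ones init (init_method == 2):
// each dot product of length K over 1 * 1 terms must sum to exactly K.
bool ones_check(const float* c, std::size_t m_times_n, int K)
{
    for(std::size_t i = 0; i < m_times_n; ++i)
        if(c[i] != static_cast<float>(K))
            return false;
    return true;
}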
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
......
@@ -46,17 +46,6 @@ template <typename ALayout,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
// 2, 256,
// 256, 256,
// 32, 8, 8,
// 32, 32,
// 4, 4,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// 1, 1, S<1, 32, 1, 8>, 8,
// ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......
@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
//
printf("up_lengths_:");
printx(up_lengths_);
print(up_lengths_);
//
printf("}");
@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
printx(left_pad_length_);
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
printx(right_pad_length_);
print(right_pad_length_);
printf("}");
}
@@ -337,12 +337,12 @@ struct left_pad
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
printx(left_pad_length_);
print(left_pad_length_);
printf("}");
}
@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
printx(right_pad_length_);
print(right_pad_length_);
printf("}");
}
@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
printx(coefficients_);
print(coefficients_);
printf("}");
}
@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
//
printf("low_lengths_ ");
printx(low_lengths_);
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
printx(up_lengths_);
print(up_lengths_);
printf("}");
}
@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
//
printf("low_lengths_ ");
printx(low_lengths_);
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
printx(low_lengths_scan_);
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
printx(up_lengths_);
print(up_lengths_);
printf("}");
}
@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
//
printf("up_lengths_");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
printx(up_lengths_scan_);
print(up_lengths_scan_);
printf("}");
}
@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
//
printf("low_idx_: ");
printx(low_idx_);
print(low_idx_);
printf("}");
}
@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
printf("insert{");
//
printx(up_lengths_);
print(up_lengths_);
printf("}");
}
@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf("}");
}
@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
printx(slice_begin_);
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
printx(slice_end_);
print(slice_end_);
printf("}");
} // namespace ck
@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf("}");
}
@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
printf("}");
@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
printx(offset_length_);
print(offset_length_);
printf("}");
}
@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
//
printf("up_lengths_: ");
printx(up_lengths_);
print(up_lengths_);
printf(", ");
printf("}");
......
@@ -195,11 +195,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default;
CK_TILE_HOST_DEVICE void print() const {
// printf("tuple{size: %d, data: [", size());
// ((printf("%d ", Is)), ...);
// printf("]}");
}
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
......
@@ -201,17 +201,6 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks;
}
template <typename StaticTensor>
CK_TILE_DEVICE void dump_static_tensor(StaticTensor& t){
constexpr auto span_2d = decltype(t)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
printf("%f,", type_convert<float>(t(i_j_idx)));
});
printf("\n");
});
}
namespace detail {
// check whether two static_distributed_tensors have the same data type and number of elements
......
@@ -89,45 +89,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
// template <typename Adaptor, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate_debug(const Adaptor& adaptor,
// const TopIndex& idx_top)
// {
// static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
// "wrong! # of dimension inconsistent");
// constexpr index_t ntransform = Adaptor::get_num_of_transform();
// constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
// constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
// constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
// multi_index<ndim_hidden> idx_hidden;
// // idx_hidden.print();
// // initialize visible index
// set_container_subset(idx_hidden, top_dim_ids, idx_top);
// // calculate hidden index
// static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
// auto itran = itran_p1 - number<1>{};
// const auto& tran = adaptor.get_transforms().at(itran);
// tran.print();
// constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
// constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
// const auto idx_up = get_container_subset(idx_hidden, dims_up);
// multi_index<dims_low.size()> idx_low;
// tran.calculate_lower_index(idx_low, idx_up);
// set_container_subset(idx_hidden, dims_low, idx_low);
// idx_hidden.print();
// });
// return tensor_adaptor_coordinate<ndim_hidden,
// remove_cvref_t<decltype(bottom_dim_ids)>,
// remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
// }
template <bool JudgeDoTransforms = true,
typename Adaptor,
typename AdaptorCoord,
......
@@ -66,16 +66,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
adaptor_coord};
}
// template <typename TensorDesc, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate_debug(const TensorDesc& tensor_desc,
// const TopIndex& idx_top)
// {
// const auto adaptor_coord = make_tensor_adaptor_coordinate_debug(tensor_desc, idx_top);
// return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
// remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
// adaptor_coord};
// }
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
......
@@ -440,13 +440,6 @@ struct tile_window_linear
// we directly use the BottomTensorView transform to compute the offset, in case of padding
auto bottom_tensor_coord =
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// if(threadIdx.x == 0) {
// bottom_tensor_coord =
// make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// printf("off00 %d %d\n",i_access, bottom_tensor_coord.get_offset() );
// bottom_tensor_coord.get_hidden_index().print();
// bottom_tensor_coord.get_index().print();
// }
return bottom_tensor_coord.get_offset();
}
else
@@ -550,7 +543,8 @@ struct tile_window_linear
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
auto linear_offset = get_bottom_linear_offset(IAccess);
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
......
@@ -53,82 +53,99 @@ struct BlockGemmASmemBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iNWarp = get_warp_id() % NWarp;
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
// }
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(MPerBlock, KPerBlock),
{0, 0},
Policy::template MakeALDSTileDistribution<Problem>());
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
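The #if 0 branch is retained as a warning: a plain array indexed by runtime loop variables cannot stay in registers on GPUs, so it spills to scratch memory, whereas statically_indexed_array driven by static_for keeps every index a compile-time constant (the same pattern is used for the B windows below). A minimal self-contained sketch of the static_for idiom, with a simplified signature rather than ck_tile's (requires C++20):

#include <cstdio>
#include <utility>

// Invokes f(integral_constant<int, 0>{}), f(integral_constant<int, 1>{}), ...
// so each loop index stays a compile-time constant, usable as a template
// argument or statically-indexed container position.
template <int N, typename F>
void static_for(F&& f)
{
    [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (f(std::integral_constant<int, Is>{}), ...);
    }(std::make_integer_sequence<int, N>{});
}

int main()
{
    static_for<4>([](auto i) { std::printf("iteration %d\n", i.value); });
}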
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(NPerBlock, KPerBlock),
{0, 0},
Policy::template MakeBLDSTileDistribution<Problem>());
auto a_block_tensor = load_tile(a_warp_window_tmp);
auto b_block_tensor = load_tile(b_warp_window_tmp);
// if (threadIdx.x == 0) {
// printf("0\n");
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
// });
// printf("\n");
// });
// }
// __syncthreads();
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read B warp tensor from B block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -192,72 +209,5 @@ struct BlockGemmASmemBSmemCRegV1
return c_block_tensor;
}
};
// construct A-warp-window
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
// {a_warp_window_tmp}};
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
// construct B-warp-window
// auto b_warp_window_tmp = make_tile_window(
// b_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
// {b_warp_window_tmp}};
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
} // namespace ck_tile
@@ -40,8 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
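The change above moves the warp layout from (2, 2) to (4, 1); with the M32N32K16 MFMA warp GEMM this changes the per-iteration coverage as follows (simple arithmetic, assuming the tuple is (warp_gemm, MWarp, NWarp)):

// WarpGemmMfmaF16F16F32M32N32K16 computes a 32 x 32 patch of C per warp:
//   (MWarp, NWarp) = (2, 2): 2*32 x 2*32 = 64 x 64 of C per block iteration
//   (MWarp, NWarp) = (4, 1): 4*32 x 1*32 = 128 x 32 of C per block iteration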
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
@@ -55,96 +54,6 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
static_assert(false, "Unsupported tensor_layout right now.");
}
else
{
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr index_t K2 = 16 / sizeof(ADataType);
constexpr index_t K1 = 2;
constexpr index_t K0 = KPerBlock / K1 / K2;
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr index_t M2 = 32; // MPERXDL
constexpr index_t M1 = 2; //MWAVE
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<2>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(false, "Unsupported shape right now.");
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
static_assert(false, "Unsupported tensor_layout right now.");
}
else
{
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr index_t K2 = 16 / sizeof(BDataType);
constexpr index_t K1 = 2;
constexpr index_t K0 = KPerBlock / K1 / K2;
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr index_t N2 = 32; // MPERXDL
constexpr index_t N1 = 2; //MWAVE
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<2>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(false, "Unsupported shape right now.");
}
}
}
};
} // namespace ck_tile