"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "e4e1c76f921471f302dd31d677361668fc08ac78"
Commit da59d3b2 authored by coderfeli's avatar coderfeli
Browse files

remove useless comments and changes

parent 6fd51c43
...@@ -15,7 +15,7 @@ using F16 = ck::half_t; ...@@ -15,7 +15,7 @@ using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -32,17 +32,15 @@ using DeviceGemmInstance = ...@@ -32,17 +32,15 @@ using DeviceGemmInstance =
PassThrough, PassThrough, PassThrough, GemmDefault, PassThrough, PassThrough, PassThrough, GemmDefault,
2, 256, 2, 256,
256, 256, 256, 256,
32, 8, 8, 32, 8, 4,
32, 32, 32, 32,
4, 4, 4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, 2, 8, 8, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
2, 8, 8, 0, 1, 8, 4, 0,
1, 1, S<1, 32, 1, 8>, 8, 1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1>; ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
//./bin/example_gemm_xdl_fp16_v2 0 0 1 5120 5120 8320 8320 8320 5120
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -22,12 +22,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -22,12 +22,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// This part comes from the Codegen // This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -40,6 +43,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -40,6 +43,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Whether doing the CShuffle (transpose before the global memory), depending on the output // Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout. // layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
...@@ -47,21 +52,27 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -47,21 +52,27 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType, using GemmEpilogue = std::conditional_t<
// CDataType, CShuffleEpilogue,
// M_Warp * N_Warp * K_Warp * Warp_Size, ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
// 64, CDataType,
// TilePartitioner::kN, kPadM,
// kPadM, kPadN,
// kPadN>>; kTilePermute,
using GemmEpilogue = ck_tile::Default2DEpilogue< kOutputRank,
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; 1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 2>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy; using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>; ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
...@@ -81,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -81,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -119,9 +119,12 @@ int run_gemm_example_with_layouts(int argc, ...@@ -119,9 +119,12 @@ int run_gemm_example_with_layouts(int argc,
} else if (init_method == 1) { } else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k); ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n); ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else { } else if (init_method == 2) {
ck_tile::FillConstant<ADataType>{1.f}(a_m_k); ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
ck_tile::FillConstant<BDataType>{1.f}(b_k_n); ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
} else {
a_m_k.SetZero();
b_k_n.SetZero();
} }
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
......
...@@ -46,17 +46,6 @@ template <typename ALayout, ...@@ -46,17 +46,6 @@ template <typename ALayout,
index_t NPerXDL, index_t NPerXDL,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
// 2, 256,
// 256, 256,
// 32, 8, 8,
// 32, 32,
// 4, 4,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// 1, 1, S<1, 32, 1, 8>, 8,
// ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
......
...@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1> ...@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
// //
printf("up_lengths_:"); printf("up_lengths_:");
printx(up_lengths_); print(up_lengths_);
// //
printf("}"); printf("}");
...@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1> ...@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("left_pad_length_: "); printf("left_pad_length_: ");
printx(left_pad_length_); print(left_pad_length_);
printf(", "); printf(", ");
// //
printf("right_pad_length_: "); printf("right_pad_length_: ");
printx(right_pad_length_); print(right_pad_length_);
printf("}"); printf("}");
} }
...@@ -337,12 +337,12 @@ struct left_pad ...@@ -337,12 +337,12 @@ struct left_pad
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("left_pad_length_: "); printf("left_pad_length_: ");
printx(left_pad_length_); print(left_pad_length_);
printf("}"); printf("}");
} }
...@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1> ...@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("right_pad_length_: "); printf("right_pad_length_: ");
printx(right_pad_length_); print(right_pad_length_);
printf("}"); printf("}");
} }
...@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()> ...@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("coefficients_: "); printf("coefficients_: ");
printx(coefficients_); print(coefficients_);
printf("}"); printf("}");
} }
...@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1> ...@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
// //
printf("low_lengths_ "); printf("low_lengths_ ");
printx(low_lengths_); print(low_lengths_);
printf(", "); printf(", ");
// //
printf("up_lengths_ "); printf("up_lengths_ ");
printx(up_lengths_); print(up_lengths_);
printf("}"); printf("}");
} }
...@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1> ...@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
// //
printf("low_lengths_ "); printf("low_lengths_ ");
printx(low_lengths_); print(low_lengths_);
printf(", "); printf(", ");
// //
printf("low_lengths_scan_ "); printf("low_lengths_scan_ ");
printx(low_lengths_scan_); print(low_lengths_scan_);
printf(", "); printf(", ");
// //
printf("up_lengths_ "); printf("up_lengths_ ");
printx(up_lengths_); print(up_lengths_);
printf("}"); printf("}");
} }
...@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()> ...@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
// //
printf("up_lengths_"); printf("up_lengths_");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("up_lengths_scan_"); printf("up_lengths_scan_");
printx(up_lengths_scan_); print(up_lengths_scan_);
printf("}"); printf("}");
} }
...@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0> ...@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
// //
printf("low_idx_: "); printf("low_idx_: ");
printx(low_idx_); print(low_idx_);
printf("}"); printf("}");
} }
...@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1> ...@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
printf("insert{"); printf("insert{");
// //
printx(up_lengths_); print(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()> ...@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1> ...@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("slice_begin_: "); printf("slice_begin_: ");
printx(slice_begin_); print(slice_begin_);
printf(", "); printf(", ");
// //
printf("slice_end_: "); printf("slice_end_: ");
printx(slice_end_); print(slice_end_);
printf("}"); printf("}");
} // namespace ck } // namespace ck
...@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1> ...@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2> ...@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
printf("}"); printf("}");
...@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1> ...@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
// //
printf("offset_length_: "); printf("offset_length_: ");
printx(offset_length_); print(offset_length_);
printf("}"); printf("}");
} }
...@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1> ...@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
printx(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
printf("}"); printf("}");
......
...@@ -195,11 +195,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...> ...@@ -195,11 +195,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>; using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default; CK_TILE_HOST_DEVICE constexpr tuple() = default;
// Debug-print hook. Currently a no-op: the body holds only a commented-out
// draft, so calling print() emits nothing.
// NOTE(review): the draft refers to `Is`, which is not introduced by the
// parameter pack visible on this struct (`T...`) — confirm what it should
// expand over before re-enabling it.
CK_TILE_HOST_DEVICE void print() const {
// printf("tuple{size: %d, data: [", size());
// ((printf("%d ", Is)), ...);
// printf("]}");
}
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST #if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U> template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us) CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
......
...@@ -201,17 +201,6 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number ...@@ -201,17 +201,6 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks; return unpacks;
} }
/// Debug helper: prints every element reached by sweeping the first two
/// distributed spans of `t`, converted to float.
///
/// Each value is printed with "%f" followed by ','; a newline is printed after
/// each index of the first span has been fully swept.
///
/// @tparam StaticTensor  Tensor type. Must expose a static
///                       get_distributed_spans() that yields at least two spans,
///                       and be callable with a tuple of two distributed indices.
/// @param  t             Tensor whose elements are printed.
template <typename StaticTensor>
CK_TILE_DEVICE void dump_static_tensor(StaticTensor& t)
{
    // `t` is declared as a reference, so decltype(t) is `StaticTensor&`; a
    // reference type cannot be used to the left of `::`. Name the class type
    // directly instead.
    constexpr auto spans = StaticTensor::get_distributed_spans();

    sweep_tile_span(spans[number<0>{}], [&](auto outer_idx) {
        sweep_tile_span(spans[number<1>{}], [&](auto inner_idx) {
            constexpr auto elem_idx = make_tuple(outer_idx, inner_idx);
            printf("%f,", type_convert<float>(t(elem_idx)));
        });
        printf("\n");
    });
}
namespace detail { namespace detail {
// check if 2 static_distributed_tensor has same data type and size of element // check if 2 static_distributed_tensor has same data type and size of element
......
...@@ -89,45 +89,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& ...@@ -89,45 +89,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden}; remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
} }
// template <typename Adaptor, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate_debug(const Adaptor& adaptor,
// const TopIndex& idx_top)
// {
// static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
// "wrong! # of dimension inconsistent");
// constexpr index_t ntransform = Adaptor::get_num_of_transform();
// constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
// constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
// constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
// multi_index<ndim_hidden> idx_hidden;
// // idx_hidden.print();
// // initialize visible index
// set_container_subset(idx_hidden, top_dim_ids, idx_top);
// // calculate hidden index
// static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
// auto itran = itran_p1 - number<1>{};
// const auto& tran = adaptor.get_transforms().at(itran);
// tran.print();
// constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
// constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
// const auto idx_up = get_container_subset(idx_hidden, dims_up);
// multi_index<dims_low.size()> idx_low;
// tran.calculate_lower_index(idx_low, idx_up);
// set_container_subset(idx_hidden, dims_low, idx_low);
// idx_hidden.print();
// });
// return tensor_adaptor_coordinate<ndim_hidden,
// remove_cvref_t<decltype(bottom_dim_ids)>,
// remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
// }
template <bool JudgeDoTransforms = true, template <bool JudgeDoTransforms = true,
typename Adaptor, typename Adaptor,
typename AdaptorCoord, typename AdaptorCoord,
......
...@@ -66,16 +66,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens ...@@ -66,16 +66,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{ remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
adaptor_coord}; adaptor_coord};
} }
// template <typename TensorDesc, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate_debug(const TensorDesc& tensor_desc,
// const TopIndex& idx_top)
// {
// const auto adaptor_coord = make_tensor_adaptor_coordinate_debug(tensor_desc, idx_top);
// return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
// remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
// adaptor_coord};
// }
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index> template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void CK_TILE_HOST_DEVICE constexpr void
......
...@@ -440,13 +440,6 @@ struct tile_window_linear ...@@ -440,13 +440,6 @@ struct tile_window_linear
// we directly use BottomTensorView transform to compute the offset, in case padding // we directly use BottomTensorView transform to compute the offset, in case padding
auto bottom_tensor_coord = auto bottom_tensor_coord =
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord); make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// if(threadIdx.x == 0) {
// bottom_tensor_coord =
// make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// printf("off00 %d %d\n",i_access, bottom_tensor_coord.get_offset() );
// bottom_tensor_coord.get_hidden_index().print();
// bottom_tensor_coord.get_index().print();
// }
return bottom_tensor_coord.get_offset(); return bottom_tensor_coord.get_offset();
} }
else else
...@@ -550,7 +543,8 @@ struct tile_window_linear ...@@ -550,7 +543,8 @@ struct tile_window_linear
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; auto bottom_tensor_flag = cached_flags_[IAccess];
auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor // read from bottom tensor
const vector_t vec_value = const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>( get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
......
...@@ -53,82 +53,99 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -53,82 +53,99 @@ struct BlockGemmASmemBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// const index_t iMWarp = get_warp_id() / NWarp; const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iNWarp = get_warp_id() % NWarp; const index_t iNWarp = get_warp_id() % NWarp;
// if(threadIdx.x == 0 && blockIdx.x==0) { // construct A-warp-window
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
// }
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
auto a_warp_window_tmp = make_tile_window( auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_block_window.get_bottom_tensor_view(),
make_tuple(MPerBlock, KPerBlock), make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{0, 0}, a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
Policy::template MakeALDSTileDistribution<Problem>()); make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window( auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_block_window.get_bottom_tensor_view(),
make_tuple(NPerBlock, KPerBlock), make_tuple(number<WG::kN>{}, number<WG::kK>{}),
{0, 0}, b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
Policy::template MakeBLDSTileDistribution<Problem>()); make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
auto a_block_tensor = load_tile(a_warp_window_tmp); #if 0 // FIXME: using array will cause register spill
auto b_block_tensor = load_tile(b_warp_window_tmp); array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
// if (threadIdx.x == 0) { for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// printf("0\n"); {
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans(); for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { move_tile_window(b_warp_windows(nIter)(kIter),
// constexpr auto i_j_idx = make_tuple(idx0, idx1); {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx))); }
// }); }
// printf("\n"); #else
// }); statically_indexed_array<
// } statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// __syncthreads(); NIterPerWarp>
using AWarpDstr = typename WG::AWarpDstr; b_warp_windows;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
using AWarpTensor = typename WG::AWarpTensor; b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
using BWarpTensor = typename WG::BWarpTensor;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor; using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop: // hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window // read A warp tensor from A block window
AWarpTensor a_warp_tensor; const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor // read B warp tensor from B Block window
BWarpTensor b_warp_tensor; const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor // read C warp tensor from C block tensor
CWarpTensor c_warp_tensor; CWarpTensor c_warp_tensor;
...@@ -192,72 +209,5 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -192,72 +209,5 @@ struct BlockGemmASmemBSmemCRegV1
return c_block_tensor; return c_block_tensor;
} }
}; };
// construct A-warp-window
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
// {a_warp_window_tmp}};
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
// construct B-warp-window
// auto b_warp_window_tmp = make_tile_window(
// b_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
// {b_warp_window_tmp}};
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
} // namespace ck_tile } // namespace ck_tile
...@@ -40,8 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -40,8 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
} }
#else #else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2); return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif #endif
} }
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
...@@ -55,96 +54,6 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -55,96 +54,6 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
static_assert(false, "Unsupported data type configuration for GEMM warp execution."); static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
} }
} }
/// Builds the static tile distribution for the A block tile.
/// (Presumably used when reading A from LDS, going by the name — confirm
/// against the caller.)
///
/// The block tile is described as M = M0 * M1 * M2 and K = K0 * K1 * K2, where
///   - K2 = 16 / sizeof(ADataType) (elements per 16-byte access),
///   - K1 = 2, K0 = KPerBlock / K1 / K2,
///   - M2 = 32 (MPerXdl), M1 = 2 (MWave), M0 = MPerBlock / (M2 * M1).
///
/// Compile-time requirements (violations are reported via static_assert):
///   - Problem::ALayout must not be ColumnMajor;
///   - get_warp_size() must be a multiple of M2 * K0.
///
/// @tparam Problem  Supplies ADataType, ALayout and BlockGemmShape::{kM, kK}.
/// @return The object produced by make_static_tile_distribution for the
///         encoding below.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution()
{
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using ALayout   = remove_cvref_t<typename Problem::ALayout>;

    constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    // Written as a dependent condition rather than `static_assert(false, ...)`
    // inside an `if constexpr` branch: compilers that predate CWG2518 reject the
    // latter even when the branch is discarded.
    static_assert(!std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
                  "Unsupported tensor_layout right now.");

    // K split: (krepeat, klane, kpack)
    constexpr index_t K2 = 16 / sizeof(ADataType);
    constexpr index_t K1 = 2;
    constexpr index_t K0 = KPerBlock / K1 / K2;

    // M split: (mrepeat, mwaves, MPerXdl)
    constexpr index_t M2 = 32; // MPerXdl
    constexpr index_t M1 = 2;  // MWave

    // Guard every divisor before it is used below.
    static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
    static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
    static_assert(K0 != 0, "K0 is zero, which will lead to a division by zero error.");

    // coalesce reading for each block
    static_assert(get_warp_size() % (M2 * K0) == 0, "Unsupported shape right now.");

    constexpr index_t M0 = MPerBlock / (M2 * M1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<2>,
                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
                                   tuple<sequence<1, 0>, sequence<2, 1>>,
                                   tuple<sequence<1, 0>, sequence<1, 2>>,
                                   sequence<1, 2, 2>,
                                   sequence<0, 0, 2>>{});
}
/// Builds the static tile distribution for the B block tile.
/// (Presumably used when reading B from LDS, going by the name — confirm
/// against the caller.)
///
/// The block tile is described as N = N0 * N1 * N2 and K = K0 * K1 * K2, where
///   - K2 = 16 / sizeof(BDataType) (elements per 16-byte access),
///   - K1 = 2, K0 = KPerBlock / K1 / K2,
///   - N2 = 32 (NPerXdl), N1 = 2 (NWave), N0 = NPerBlock / (N2 * N1).
///
/// Compile-time requirements (violations are reported via static_assert):
///   - Problem::BLayout must not be RowMajor;
///   - get_warp_size() must be a multiple of N2 * K0.
///
/// @tparam Problem  Supplies BDataType, BLayout and BlockGemmShape::{kN, kK}.
/// @return The object produced by make_static_tile_distribution for the
///         encoding below.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
{
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using BLayout   = remove_cvref_t<typename Problem::BLayout>;

    constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    // Written as a dependent condition rather than `static_assert(false, ...)`
    // inside an `if constexpr` branch: compilers that predate CWG2518 reject the
    // latter even when the branch is discarded.
    static_assert(!std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>,
                  "Unsupported tensor_layout right now.");

    // K split: (krepeat, klane, kpack)
    constexpr index_t K2 = 16 / sizeof(BDataType);
    constexpr index_t K1 = 2;
    constexpr index_t K0 = KPerBlock / K1 / K2;

    // N split: (nrepeat, nwaves, NPerXdl)
    constexpr index_t N2 = 32; // NPerXdl
    constexpr index_t N1 = 2;  // NWave

    // Guard every divisor before it is used below. The diagnostics name the
    // N-dimension factors actually being checked.
    static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
    static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
    static_assert(K0 != 0, "K0 is zero, which will lead to a division by zero error.");

    // coalesce reading for each block
    static_assert(get_warp_size() % (N2 * K0) == 0, "Unsupported shape right now.");

    constexpr index_t N0 = NPerBlock / (N2 * N1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<2>,
                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
                                   tuple<sequence<0, 1>, sequence<2, 1>>,
                                   tuple<sequence<0, 1>, sequence<1, 2>>,
                                   sequence<1, 2, 2>,
                                   sequence<0, 0, 2>>{});
}
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment