Commit 730c5fff authored by coderfeli's avatar coderfeli
Browse files

fix linear

parent 4525c5d7
...@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1> ...@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
// //
printf("up_lengths_:"); printf("up_lengths_:");
print(up_lengths_); printx(up_lengths_);
// //
printf("}"); printf("}");
...@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1> ...@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("left_pad_length_: "); printf("left_pad_length_: ");
print(left_pad_length_); printx(left_pad_length_);
printf(", "); printf(", ");
// //
printf("right_pad_length_: "); printf("right_pad_length_: ");
print(right_pad_length_); printx(right_pad_length_);
printf("}"); printf("}");
} }
...@@ -337,12 +337,12 @@ struct left_pad ...@@ -337,12 +337,12 @@ struct left_pad
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("left_pad_length_: "); printf("left_pad_length_: ");
print(left_pad_length_); printx(left_pad_length_);
printf("}"); printf("}");
} }
...@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1> ...@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("right_pad_length_: "); printf("right_pad_length_: ");
print(right_pad_length_); printx(right_pad_length_);
printf("}"); printf("}");
} }
...@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()> ...@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("coefficients_: "); printf("coefficients_: ");
print(coefficients_); printx(coefficients_);
printf("}"); printf("}");
} }
...@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1> ...@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
// //
printf("low_lengths_ "); printf("low_lengths_ ");
print(low_lengths_); printx(low_lengths_);
printf(", "); printf(", ");
// //
printf("up_lengths_ "); printf("up_lengths_ ");
print(up_lengths_); printx(up_lengths_);
printf("}"); printf("}");
} }
...@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1> ...@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
// //
printf("low_lengths_ "); printf("low_lengths_ ");
print(low_lengths_); printx(low_lengths_);
printf(", "); printf(", ");
// //
printf("low_lengths_scan_ "); printf("low_lengths_scan_ ");
print(low_lengths_scan_); printx(low_lengths_scan_);
printf(", "); printf(", ");
// //
printf("up_lengths_ "); printf("up_lengths_ ");
print(up_lengths_); printx(up_lengths_);
printf("}"); printf("}");
} }
...@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()> ...@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
// //
printf("up_lengths_"); printf("up_lengths_");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("up_lengths_scan_"); printf("up_lengths_scan_");
print(up_lengths_scan_); printx(up_lengths_scan_);
printf("}"); printf("}");
} }
...@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0> ...@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
// //
printf("low_idx_: "); printf("low_idx_: ");
print(low_idx_); printx(low_idx_);
printf("}"); printf("}");
} }
...@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1> ...@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
printf("insert{"); printf("insert{");
// //
print(up_lengths_); printx(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()> ...@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1> ...@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("slice_begin_: "); printf("slice_begin_: ");
print(slice_begin_); printx(slice_begin_);
printf(", "); printf(", ");
// //
printf("slice_end_: "); printf("slice_end_: ");
print(slice_end_); printx(slice_end_);
printf("}"); printf("}");
} // namespace ck } // namespace ck
...@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1> ...@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf("}"); printf("}");
} }
...@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2> ...@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
printf("}"); printf("}");
...@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1> ...@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
// //
printf("offset_length_: "); printf("offset_length_: ");
print(offset_length_); printx(offset_length_);
printf("}"); printf("}");
} }
...@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1> ...@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
// //
printf("up_lengths_: "); printf("up_lengths_: ");
print(up_lengths_); printx(up_lengths_);
printf(", "); printf(", ");
printf("}"); printf("}");
......
...@@ -230,3 +230,6 @@ ...@@ -230,3 +230,6 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif #endif
template<typename T>
CK_TILE_HOST_DEVICE void printx(T a = {}) {a.print();}
\ No newline at end of file
...@@ -52,7 +52,13 @@ struct array ...@@ -52,7 +52,13 @@ struct array
data[i] = vlast; data[i] = vlast;
} }
} }
CK_TILE_HOST_DEVICE void print() const {
printf("array{size: %d, data: ", size());
for (index_t i = 0; i < size(); i++) {
printf("%d,", int(get(i)));
}
}
template <typename Y, template <typename Y,
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> || typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
std::is_constructible_v<Y, value_type>>> std::is_constructible_v<Y, value_type>>>
......
...@@ -195,6 +195,11 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...> ...@@ -195,6 +195,11 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>; using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default; CK_TILE_HOST_DEVICE constexpr tuple() = default;
CK_TILE_HOST_DEVICE void print() const {
// printf("tuple{size: %d, data: [", size());
// ((printf("%d ", Is)), ...);
// printf("]}");
}
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST #if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U> template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us) CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
......
...@@ -50,22 +50,6 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_, ...@@ -50,22 +50,6 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_, template <typename DistributedTensor_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
......
...@@ -89,6 +89,45 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& ...@@ -89,6 +89,45 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden}; remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
} }
// template <typename Adaptor, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate_debug(const Adaptor& adaptor,
// const TopIndex& idx_top)
// {
// static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
// "wrong! # of dimension inconsistent");
// constexpr index_t ntransform = Adaptor::get_num_of_transform();
// constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
// constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
// constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
// multi_index<ndim_hidden> idx_hidden;
// // idx_hidden.print();
// // initialize visible index
// set_container_subset(idx_hidden, top_dim_ids, idx_top);
// // calculate hidden index
// static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
// auto itran = itran_p1 - number<1>{};
// const auto& tran = adaptor.get_transforms().at(itran);
// tran.print();
// constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
// constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
// const auto idx_up = get_container_subset(idx_hidden, dims_up);
// multi_index<dims_low.size()> idx_low;
// tran.calculate_lower_index(idx_low, idx_up);
// set_container_subset(idx_hidden, dims_low, idx_low);
// idx_hidden.print();
// });
// return tensor_adaptor_coordinate<ndim_hidden,
// remove_cvref_t<decltype(bottom_dim_ids)>,
// remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
// }
template <bool JudgeDoTransforms = true, template <bool JudgeDoTransforms = true,
typename Adaptor, typename Adaptor,
typename AdaptorCoord, typename AdaptorCoord,
......
...@@ -66,6 +66,16 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens ...@@ -66,6 +66,16 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{ remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
adaptor_coord}; adaptor_coord};
} }
// template <typename TensorDesc, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate_debug(const TensorDesc& tensor_desc,
// const TopIndex& idx_top)
// {
// const auto adaptor_coord = make_tensor_adaptor_coordinate_debug(tensor_desc, idx_top);
// return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
// remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
// adaptor_coord};
// }
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index> template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void CK_TILE_HOST_DEVICE constexpr void
......
...@@ -440,6 +440,13 @@ struct tile_window_linear ...@@ -440,6 +440,13 @@ struct tile_window_linear
// we directly use BottomTensorView transform to compute the offset, in case padding // we directly use BottomTensorView transform to compute the offset, in case padding
auto bottom_tensor_coord = auto bottom_tensor_coord =
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord); make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// if(threadIdx.x == 0) {
// bottom_tensor_coord =
// make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// printf("off00 %d %d\n",i_access, bottom_tensor_coord.get_offset() );
// bottom_tensor_coord.get_hidden_index().print();
// bottom_tensor_coord.get_index().print();
// }
return bottom_tensor_coord.get_offset(); return bottom_tensor_coord.get_offset();
} }
else else
...@@ -468,14 +475,16 @@ struct tile_window_linear ...@@ -468,14 +475,16 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{}; constexpr auto IAccess = number<i_access_>{};
...@@ -518,13 +527,7 @@ struct tile_window_linear ...@@ -518,13 +527,7 @@ struct tile_window_linear
}; };
WINDOW_DISPATCH_ISSUE(); WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
auto dst_tensor = make_static_distributed_tensor<DataType>(TileDstr{});
load(dst_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
return dst_tensor; return dst_tensor;
} }
...@@ -547,8 +550,7 @@ struct tile_window_linear ...@@ -547,8 +550,7 @@ struct tile_window_linear
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor // read from bottom tensor
const vector_t vec_value = const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>( get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
......
...@@ -232,8 +232,8 @@ struct BlockGemmARegBRegCRegV2 ...@@ -232,8 +232,8 @@ struct BlockGemmARegBRegCRegV2
CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor) CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{ {
auto tileDist = BlockTensor::get_tile_distribution(); auto tileDist = BlockTensor::get_tile_distribution();
load_tile(block_tensor, make_tile_window(block_window, tileDist)); // load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist)); load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
} }
// C = A * B // C = A * B
......
...@@ -260,12 +260,13 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -260,12 +260,13 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0; ALdsTile a_block_tile0;
BLdsTile b_block_tile0; BLdsTile b_block_tile0;
auto a_lds_ld_window0 = make_tile_window_linear(a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr); auto a_lds_ld_window0 = make_tile_window_linear(a_lds_window0, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr); auto a_lds_ld_window1 = make_tile_window_linear(a_lds_window1, ALdsTileDistr);
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_window0, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_window1, BLdsTileDistr);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0); load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); load_tile(b_block_tile0, b_lds_ld_window0);
// LDS write 1 // LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
...@@ -285,9 +286,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -285,9 +286,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1); load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); load_tile(b_block_tile1, b_lds_ld_window1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
...@@ -300,7 +300,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -300,7 +300,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0); load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); load_tile(b_block_tile0, b_lds_ld_window0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
...@@ -319,7 +319,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -319,7 +319,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1); load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); load_tile(b_block_tile1, b_lds_ld_window1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
...@@ -329,7 +329,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -329,7 +329,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0); load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); load_tile(b_block_tile0, b_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
//1 //1
...@@ -344,7 +344,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -344,7 +344,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds(); block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1); load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); load_tile(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
// 2 // 2
......
...@@ -59,14 +59,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -59,14 +59,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
// TODO: this 8 is AK1! should be a policy parameter! // TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<kMPerBlock * 8>{}, number<8>{}, number<1>{}),
number<8>{}, number<8>{},
number<1>{}); number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor( constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0, a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock), make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))), make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -88,8 +88,10 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -88,8 +88,10 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto b_lds_block_desc = transform_tensor_descriptor( constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0, b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock), // make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))), // make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment