Commit 5a2d93d4 authored by coderfeli
Browse files

revert code

parent 6a07464b
...@@ -46,6 +46,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_, ...@@ -46,6 +46,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{}); return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
} }
/// Overload of load_tile for a linear tile window and a caller-provided destination.
///
/// Forwards `dst_tile` to `tile_window.load(...)` together with the access index
/// `number<-1>` and the out-of-bounds-check tag, and returns whatever that call returns.
///
/// @tparam oob_conditional_check  forwarded unchanged as `bool_constant<...>` (default true)
/// @param dst_tile     destination distributed tensor handed to the window's load
/// @param tile_window  linear tile window to read from
template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    // -1 is also the default access index of tile_window_linear::load
    // (presumably "issue every access" -- TODO confirm against WINDOW_DISPATCH_ISSUE).
    constexpr auto access_selector = number<-1>{};
    constexpr auto oob_check_tag   = bool_constant<oob_conditional_check>{};

    return tile_window.load(dst_tile, access_selector, oob_check_tag);
}
template <typename DistributedTensor_, template <typename DistributedTensor_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
......
...@@ -453,6 +453,58 @@ struct tile_window_linear ...@@ -453,6 +453,58 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
/// Load this window's data into a caller-provided distributed tensor.
///
/// For every issued access, one `vector_t` is read from the bottom tensor view using the
/// cached coordinate/flag for that access, and its `traits::ScalarPerVector` elements are
/// written one by one into `dst_tensor`'s thread buffer at the offset computed by the tile
/// distribution's ys-to-d descriptor.
///
/// @tparam DistributedTensor      destination tensor type; must provide `get_thread_buffer()`
/// @tparam i_access               access selector (default -1); not used directly in this body,
///                                presumably consumed by WINDOW_DISPATCH_ISSUE -- TODO confirm
/// @tparam oob_conditional_check  forwarded to `get_vectorized_elements`
/// @param dst_tensor  destination, taken by reference: the function returns nothing, so
///                    taking it by value would fill a local copy and silently drop the data
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         number<i_access>                     = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    using vector_t = typename traits::vector_t;
    using SFC_Ys   = typename traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // Performs a single access: one vectorized read followed by a scatter into dst_tensor.
    auto issue = [&](auto i_access_) {
        constexpr auto IAccess = number<i_access_>{};

        // Several accesses may share one cached coordinate; the map selects which one.
        constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
        auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
        auto bottom_tensor_flag         = cached_flags_[IAccess];

        constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

        // read from bottom tensor
        const vector_t vec_value =
            get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{});
#if 1
        // data index [y0, y1, ...]
        constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

        // write into distributed tensor, one scalar of the vector at a time
        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
            // advance only along the vectorized Y dimension by j
            constexpr auto idx_ys = generate_tuple(
                [&](auto jj) {
                    return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                },
                number<NDimY>{});

            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

            dst_tensor.get_thread_buffer().template at<d>() =
                vec_value.template get_as<DataType>()[j];
        });
#else
        // NOTE(review): disabled whole-vector store. `idx_ys_start` is not defined in this
        // function, so this branch will not compile if enabled without defining it first.
        constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
        static_assert(d % traits::ScalarPerVector == 0);

        dst_tensor.get_thread_buffer().template get_as<vector_t>()(
            number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
    };

    WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
......
...@@ -210,7 +210,7 @@ struct BlockGemmARegBRegCRegV2 ...@@ -210,7 +210,7 @@ struct BlockGemmARegBRegCRegV2
auto tileDist = BlockTensor::get_tile_distribution(); auto tileDist = BlockTensor::get_tile_distribution();
return load_tile(block_tensor, make_tile_window(block_window, tileDist)); return load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile_raw(block_tensor, make_tile_window_linear_raw(block_window, tileDist)); // load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return; // return;
} }
......
...@@ -256,8 +256,40 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -256,8 +256,40 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{})); using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{}));
ALdsTile a_block_tile0; ALdsTile a_block_tile0;
BLdsTile b_block_tile0; BLdsTile b_block_tile0;
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
...@@ -274,8 +306,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -274,8 +306,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
...@@ -286,8 +318,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -286,8 +318,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong // pong
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
...@@ -303,8 +335,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -303,8 +335,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3 // 3
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
...@@ -312,8 +344,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -312,8 +344,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2 // 2
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
//1 //1
...@@ -324,8 +356,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -324,8 +356,8 @@ struct GemmPipelineAGmemBGmemCRegV1
} else { } else {
{ {
block_sync_lds(); block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{})); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
// 2 // 2
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment