Commit 5a2d93d4 authored by coderfeli's avatar coderfeli
Browse files

revert code

parent 6a07464b
......@@ -46,6 +46,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
......
......@@ -453,6 +453,58 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor dst_tensor, number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
......
......@@ -210,7 +210,7 @@ struct BlockGemmARegBRegCRegV2
auto tileDist = BlockTensor::get_tile_distribution();
return load_tile(block_tensor, make_tile_window(block_window, tileDist));
// load_tile_raw(block_tensor, make_tile_window_linear_raw(block_window, tileDist));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return;
}
......
......@@ -256,8 +256,40 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{}));
ALdsTile a_block_tile0;
BLdsTile b_block_tile0;
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
......@@ -274,8 +306,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
{
block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
......@@ -286,8 +318,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong
{
block_sync_lds();
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
......@@ -303,8 +335,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
......@@ -312,8 +344,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
block_sync_lds();
load_tile(a_block_tile0, make_tile_window(a_lds_window0, ALdsTileDistr{}));
load_tile(b_block_tile0, make_tile_window(b_lds_window0, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
//1
......@@ -324,8 +356,8 @@ struct GemmPipelineAGmemBGmemCRegV1
} else {
{
block_sync_lds();
load_tile(a_block_tile1, make_tile_window(a_lds_window1, ALdsTileDistr{}));
load_tile(b_block_tile1, make_tile_window(b_lds_window1, BLdsTileDistr{}));
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment