Commit 97efebdb authored by Qianfeng Zhang
Browse files

Special treatment for hdim-96 to save vgprs in qr_ks_vs_async pipeline

parent a94ac4bb
...@@ -76,25 +76,27 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -76,25 +76,27 @@ struct BlockFmhaPipelineQRKSVSAsync
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
if constexpr(kQKHeaddim <= 32) if constexpr(kQKHeaddim == 32)
{ {
return 2; return 2;
} }
else if constexpr(kQKHeaddim <= 64) else if constexpr(kQKHeaddim == 64)
{ {
return 2; return 2;
} }
else if constexpr(kQKHeaddim <= 128) else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
{ {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 1; return 2;
} }
else if constexpr(kQKHeaddim <= 256) else if constexpr(kQKHeaddim == 256)
{ {
return 1; return 1;
} }
else
return 1;
} }
}(); }();
...@@ -170,7 +172,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -170,7 +172,6 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>(); constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
static_assert(NumKLdsBuffers >= 2); static_assert(NumKLdsBuffers >= 2);
static_assert(NumVLdsBuffers >= 2);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
...@@ -269,7 +270,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -269,7 +270,13 @@ struct BlockFmhaPipelineQRKSVSAsync
using k_tile_type = decltype(load_tile(k_dram_window)); using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k0_loops> k_tiles; auto k_tiles = [&]() {
// for hdim-96 and hdim-160, try to save vgprs
if constexpr(kQKHeaddim < kSubQKHeaddim)
return statically_indexed_array<k_tile_type, 2>{};
else
return statically_indexed_array<k_tile_type, k0_loops>{};
}();
k_tiles[I0] = load_tile(k_dram_window); k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
...@@ -296,121 +303,158 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -296,121 +303,158 @@ struct BlockFmhaPipelineQRKSVSAsync
do do
{ {
if(i_total_loops == 0) // executed by fist iteration if constexpr(kQKHeaddim == kSubQKHeaddim)
{ {
if(num_total_loop > 1) // there are multiple iterations if(i_total_loops == 0) // executed by fist iteration
{ {
auto k_lds_window_tmp = if(num_total_loop > 1) // there are multiple iterations
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}); {
store_tile(k_lds_window_tmp, k_tiles[I0]); auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window); k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
k_lds_window_tmp = get_slice_tile( k_lds_window_tmp = get_slice_tile(
k_lds_window, k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{}, sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{}); sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]); store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
}); });
move_tile_window(k_dram_window, {kN0, -k0_loops * kK0}); move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[number<i_k0>{}] = load_tile(k_dram_window); k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
}); });
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, q_tiles[number<k0_loops - 1>{}], k_lds_window_tmp); gemm_0(s_acc, q_tiles[number<k0_loops - 1>{}], k_lds_window_tmp);
}
else // there is only single iteration
{
auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1)
{
k_lds_window_tmp = get_slice_tile(
k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
};
});
// move_tile_window(k_dram_window, {0, -k0_loops * kK0});
}
} }
else // there is only single iteration else // executed by intermediate and last iteration
{ {
auto k_lds_window_tmp = if(i_total_loops < num_total_loop - 1) // intermediate iteration
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}); {
store_tile(k_lds_window_tmp, k_tiles[I0]); move_tile_window(k_dram_window, {kN0, 0});
clear_tile(s_acc); // initialize C static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp = get_slice_tile(
k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
static_for<0, k0_loops, 1>{}([&](auto i_k0) { k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
{ move_tile_window(k_dram_window, {0, kK0});
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
block_sync_lds(); if constexpr(i_k0 == 0)
// execute current unroll of gemm_0 clear_tile(s_acc);
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1) block_sync_lds();
{ gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
k_lds_window_tmp = get_slice_tile( });
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
}
else // last iteration
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp = get_slice_tile(
k_lds_window, k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{}, sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{}); sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]); store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
};
});
// move_tile_window(k_dram_window, {0, -k0_loops * kK0}); if constexpr(i_k0 == 0)
} clear_tile(s_acc);
block_sync_lds();
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
});
};
};
} }
else // executed by intermediate and last iteration else
{ {
if(i_total_loops < num_total_loop - 1) // intermediate iteration auto k_lds_window_tmp =
{ get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
move_tile_window(k_dram_window, {kN0, 0}); store_tile(k_lds_window_tmp, k_tiles[I0]);
static_for<0, k0_loops, 1>{}([&](auto i_k0) { clear_tile(s_acc); // initialize C
auto k_lds_window_tmp =
get_slice_tile(k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
k_tiles[number<i_k0>{}] = load_tile(k_dram_window); static_for<0, k0_loops, 1>{}([&](auto i_k0) {
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); {
k_tiles[number<(i_k0 + 1) % 2>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
if constexpr(i_k0 == 0) block_sync_lds();
clear_tile(s_acc); // execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
block_sync_lds(); if constexpr(i_k0 < k0_loops - 1)
// execute last unroll of gemm_0 {
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); k_lds_window_tmp = get_slice_tile(
}); k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<(i_k0 + 1) % 2>{}]);
};
});
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); if(i_total_loops < num_total_loop - 1)
}
else // last iteration
{ {
static_for<0, k0_loops, 1>{}([&](auto i_k0) { move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
auto k_lds_window_tmp = k_tiles[I0] = load_tile(k_dram_window);
get_slice_tile(k_lds_window, move_tile_window(k_dram_window, {0, kK0});
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
if constexpr(i_k0 == 0)
clear_tile(s_acc);
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
});
}; };
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment