"...resnet50_tensorflow.git" did not exist on "b5286d69525724c6ae2c79d246e0b65c7a1ba37f"
Commit 60356c90 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Enable splitkv async pipeline

parent de6dd79f
...@@ -626,8 +626,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -626,8 +626,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
......
...@@ -377,6 +377,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -377,6 +377,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
number<-1>{}, number<-1>{},
k_oob_ck, k_oob_ck,
k_pre_np); k_pre_np);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
...@@ -697,13 +699,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -697,13 +699,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}); });
} }
i_total_loops++; i_total_loops++;
// load the first K tile for next iteration
if(i_total_loops < num_total_loop) if(i_total_loops < num_total_loop)
{ {
// move K tile windows // move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window( i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0}); i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); k_dram_window = make_tile_window(
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();
if constexpr(k1_loops >= 2 && if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{})) LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment