Commit 57baa79f authored by Jakub Piasecki
Browse files

Merge remote-tracking branch 'origin/develop' into jakpiase/ck_tile_examples_package

parents ce8c840f a3757a5f
...@@ -255,20 +255,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -255,20 +255,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
} }
else else
{ {
// Tail number always Full - #PrefetchStages std::ostringstream err;
if(tail_num == ck_tile::TailNumber::Full) err << "Num K loop must be larger than number of prefetech stages."
{ << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__
Run(ck_tile::bool_constant<false>{}, << ":" << __LINE__ << ", in function: " << __func__;
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{}); throw std::runtime_error(err.str());
}
else
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
} }
return ave_time; return ave_time;
......
...@@ -310,7 +310,7 @@ struct SimplifiedGenericAttentionMask ...@@ -310,7 +310,7 @@ struct SimplifiedGenericAttentionMask
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; const index_t split_start = x_per_split * i_split;
const index_t split_end = split_start + x_per_split; const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end)); ck_tile::min(origin_end, split_end));
......
...@@ -742,7 +742,7 @@ struct FmhaFwdSplitKVKernel ...@@ -742,7 +742,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
else else
{ {
......
...@@ -343,6 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS ...@@ -343,6 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// moving k_dram_window is an in-page-block operation, so there is // moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here. // no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
// ensure LDS access by Q is done before it is overwritten by K
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
do do
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment