"docs/vscode:/vscode.git/clone" did not exist on "98d46a3f085cc7529c9ddf3b7c4918cc248e9162"
Unverified Commit 58e7f37f authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Undo padding-flag changes in fmha_fwd_kernel.hpp (#1725)

parent 86990558
...@@ -998,14 +998,14 @@ struct FmhaFwdKernel ...@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() { const auto k_dram = [&]() {
...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel ...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenK, kPadHeadDimQ>{});
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel ...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
else else
{ {
...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel ...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
}(); }();
...@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel ...@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentBias>{}, number<FmhaPipeline::kAlignmentBias>{},
number<1>{}); number<1>{});
return pad_tensor_view( return pad_tensor_view(bias_dram_naive,
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{}); bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}(); }();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment