Commit 812521df authored by wangshaojie6's avatar wangshaojie6
Browse files

add k padding

parent 12e7df12
......@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
......
......@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
......
......@@ -196,7 +196,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
// static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
......
......@@ -384,20 +384,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
const auto b_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -415,6 +405,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return;
}
// fetch origin N dim(before padding)
const index_t n_raw = b_grid_desc_bk0_n_bk1.GetTransforms()[I0].GetUpperLengths()[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
......@@ -794,20 +787,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf,
num_k_block_main_loop);
if constexpr(!MaskOutUpperTriangle)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
}
else
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
......@@ -826,29 +807,43 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_for<0, n4, 1>{}([&](auto n4_i) {
const index_t n_global = nstartgroup + n4_i;
const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
if(n_global > m_global)
if constexpr(MaskOutUpperTriangle)
{
if(n_global > m_global || n_global > n_raw)
{
acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
}
else
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op(acc_thread_buf(acc_offset),
acc_thread_buf[acc_offset]);
#else
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(acc_offset),
acc_element_op,
}
}
else
{
// ignore m_global;
if(n_global > n_raw)
{
acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
}
else
{
acc_element_op(acc_thread_buf(acc_offset),
acc_thread_buf[acc_offset]);
#endif
}
}
});
});
});
});
}
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment