Commit c701b3ff authored by Anthony Chang's avatar Anthony Chang
Browse files

refactor attention padding to better fit in unit tests

This streamlines usage where "ResetNaNToMinusInf" is now hidden from user facing device op.
Also added compile-time conditionals that load OOB value as NaN only after padding is enabled
parent f2e11f93
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_custom_target(example_batched_gemm_scale_softmax_gemm)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
......@@ -49,14 +49,9 @@ using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
// When using padded DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle kernel, 2 specs should be set:
// 1. GemmSpecialization should be set to MNPadding(or NPadding in future)
// 2. Acc0ElementOp should be set to ScaleAndResetNaNToMinusInfinity
// Otherwise, wrong result may be produced.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAndResetNaNToMinusInfinity;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
......
......@@ -662,7 +662,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
matrix_padder.PadN>;
// Argument
// FIXME: constness
......
......@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -198,6 +199,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -617,7 +622,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
matrix_padder.PadN>;
// Argument
struct Argument : public BaseArgument
......
......@@ -75,7 +75,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
bool PadN>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -330,6 +331,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
template <bool Pred>
struct ElementOpPredicatedResetNaNToMinusInf;
template <>
struct ElementOpPredicatedResetNaNToMinusInf<true>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
if(ck::math::isnan(x))
{
y = -ck::NumericLimits<float>::Infinity();
}
else
{
op(y, x);
}
}
};
template <>
struct ElementOpPredicatedResetNaNToMinusInf<false>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
op(y, x);
}
};
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -348,14 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto a_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
const auto b_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN());
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -721,8 +758,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
num_k_block_main_loop);
// Acc0 elementwise Op
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment