Commit f6ceef78 authored by ThomasNing

merge with the develop branch

parents 536c5458 25935b57
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
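The guard added above rejects problems whose A, B, or C descriptor would span more than 2 GB in bytes, presumably because the 32-bit offsets used for buffer addressing cannot reach further (that motivation is my reading, not stated in this hunk). A minimal standalone sketch of the same bound check (hypothetical helper, not CK API):

#include <cstdint>

template <typename DataType>
bool fits_in_2gb(std::int64_t element_space_size)
{
    // 2 GB in bytes, matching the (long_index_t{1} << 31) constant above
    constexpr std::int64_t TwoGB = std::int64_t{1} << 31;
    return element_space_size * static_cast<std::int64_t>(sizeof(DataType)) <= TwoGB;
}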
......
......@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
......@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
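With the GemmSpec-conditional branch disabled under #if 0, the C descriptor above is now always built with right-pad transforms; when the problem size is already a multiple of the block tile, MPad == M and NPad == N, so the pad amount is zero and the view is unchanged. A sketch of how the padded sizes are typically derived (the round-up rule below is an assumption, not quoted from this file):

#include <cstdint>

constexpr std::int64_t round_up(std::int64_t x, std::int64_t tile)
{
    return (x + tile - 1) / tile * tile;
}
// e.g. MPad = round_up(M, MPerBlock); the pad length passed above is MPad - M,
// which is 0 whenever M is already tile-aligned.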
struct Problem
......@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
......@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value))
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
{
......
......@@ -36,10 +36,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
......@@ -56,7 +55,7 @@ __global__ void
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
......@@ -69,10 +68,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Passing two lds pointers is the key to telling the compiler that ds_read/write
// can operate on different lds chunks at the same time without an ordering dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -93,7 +91,7 @@ __global__ void
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}
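The comment above about passing two LDS pointers is what distinguishes this kernel variant: giving the compiler two distinct __shared__ objects lets it overlap ds_read from one buffer with ds_write to the other without inserting ordering waits. A minimal illustrative HIP sketch of that ping-pong layout (kernel name and sizes are made up):

__global__ void two_lds_buffers_sketch()
{
    // Two separate LDS allocations, as in the *_2lds kernel above; sizes are illustrative.
    __shared__ char lds_ping[4096];
    __shared__ char lds_pong[4096];
    // A pipelined main loop can write the next tile into lds_pong while reading the
    // current tile from lds_ping; because these are provably different objects, the
    // compiler does not have to serialize the two LDS streams.
    (void)lds_ping;
    (void)lds_pong;
}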
template <typename ALayout,
......@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
......@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
......@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
......@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......
......@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1>
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvFwdToGemm
{
private:
......@@ -46,10 +47,10 @@ struct TransformConvFwdToGemm
}
template <typename ConvDimsType>
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
......@@ -59,7 +60,7 @@ struct TransformConvFwdToGemm
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const index_t N = a_g_n_c_wis_lengths[I1];
const IndexType N = a_g_n_c_wis_lengths[I1];
if(element_space_size > TwoGB)
{
......@@ -70,7 +71,7 @@ struct TransformConvFwdToGemm
{
// Find the least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N); divisors above that are not considered.
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
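The loop above (only partially visible in this hunk) searches for the smallest divisor of N that is at least the required shrink factor, so the batch dimension can be split into pieces whose tensors stay under 2 GB. A standalone sketch of that search; the fall-back when no divisor is found is an assumption on my part:

#include <cstdint>

inline std::int64_t least_divisor_at_least(std::int64_t N, std::int64_t minimum)
{
    // Mirrors the loop above: scan candidate divisors from the required minimum up to sqrt(N).
    for(std::int64_t d = minimum; d * d <= N; ++d)
    {
        if(N % d == 0)
            return d;
    }
    return N; // assumption: if nothing divides evenly, fall back to the full split
}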
......@@ -98,6 +99,53 @@ struct TransformConvFwdToGemm
public:
__host__ __device__ constexpr TransformConvFwdToGemm() {}
template <typename TransformConvFwdToGemmBase>
__host__ __device__
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
{
}
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
......@@ -126,6 +174,8 @@ struct TransformConvFwdToGemm
DiStride_{I1},
HiStride_{I1},
WiStride_{a_g_n_c_wis_strides[I3]},
DoStride_{I1},
HoStride_{I1},
WoStride_{c_g_n_k_wos_strides[I3]},
XStride_{b_g_k_c_xs_strides[I3]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -133,6 +183,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -150,10 +201,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I0]},
ZYX_{X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -164,7 +215,6 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Wo_;
}
template <typename ConvDimsType,
......@@ -195,6 +245,8 @@ struct TransformConvFwdToGemm
DiStride_{I1},
HiStride_{a_g_n_c_wis_strides[I3]},
WiStride_{a_g_n_c_wis_strides[I4]},
DoStride_{I1},
HoStride_{c_g_n_k_wos_strides[I3]},
WoStride_{c_g_n_k_wos_strides[I4]},
XStride_{b_g_k_c_xs_strides[I4]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -202,6 +254,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -219,10 +272,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -233,7 +286,6 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Ho_ * Wo_;
}
template <typename ConvDimsType,
......@@ -264,6 +316,8 @@ struct TransformConvFwdToGemm
DiStride_{a_g_n_c_wis_strides[I3]},
HiStride_{a_g_n_c_wis_strides[I4]},
WiStride_{a_g_n_c_wis_strides[I5]},
DoStride_{c_g_n_k_wos_strides[I3]},
HoStride_{c_g_n_k_wos_strides[I4]},
WoStride_{c_g_n_k_wos_strides[I5]},
XStride_{b_g_k_c_xs_strides[I5]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -271,6 +325,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -288,10 +343,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -302,7 +357,122 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
}
__host__ bool AreDescriptorsSmallerThan2GB() const
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
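The two space sizes above follow the usual strided-descriptor formula: the last reachable element sits at the sum of (length_i - 1) * stride_i over all dimensions, so the element space is that plus one. The same computation in a generic standalone form (sketch):

#include <cstdint>
#include <vector>

inline std::int64_t element_space_size(const std::vector<std::int64_t>& lengths,
                                       const std::vector<std::int64_t>& strides)
{
    std::int64_t last_offset = 0;
    for(std::size_t i = 0; i < lengths.size(); ++i)
        last_offset += (lengths[i] - 1) * strides[i];
    return last_offset + 1; // highest reachable element index + 1
}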
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow the split only if the whole left padding falls in the left tensor and the whole
// right padding falls in the right tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split the output in half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
// Calculate offsets
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
c_right_offset = (Do_ / 2) * DoStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
c_right_offset = (Ho_ / 2) * HoStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
conv_to_gemm_transformer_right.Wi_ =
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
c_right_offset = (Wo_ / 2) * WoStride_;
}
// Return the left transformer, the right transformer, the right offset into the input
// and the right offset into the output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
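SplitConvProblem halves the output along whichever spatial dimension allows a clean split, and returns the two transformers together with the element offsets at which the right half starts in the input and output. The offset arithmetic for the D-split case, extracted into a standalone sketch (the struct and function names are mine):

#include <cstdint>

struct SplitD
{
    std::int64_t do_left, do_right, a_right_offset, c_right_offset;
};

inline SplitD split_along_d(std::int64_t Do, std::int64_t ConvStrideD, std::int64_t InLeftPadD,
                            std::int64_t DiStride, std::int64_t DoStride)
{
    SplitD s{};
    s.do_left        = Do / 2;                                           // left output half
    s.do_right       = Do - Do / 2;                                      // right output half
    s.a_right_offset = ((Do / 2) * ConvStrideD - InLeftPadD) * DiStride; // first input row used by the right half
    s.c_right_offset = (Do / 2) * DoStride;                              // first output row of the right half
    return s;
}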
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimensions as
......@@ -320,20 +490,27 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
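The descriptor above now keeps N and Wo as separate dimensions, each with its own stride, and merges them into the GEMM M dimension, instead of using the old flat NDoHoWo_ extent with a single Wi stride. In plain arithmetic the merge means (sketch):

#include <cstdint>

struct NWo
{
    std::int64_t n, wo;
};

inline NWo unmerge_m(std::int64_t m, std::int64_t Wo)
{
    return {m / Wo, m % Wo}; // merged M index -> (batch, output-width) pair
}
// offset(m, c) = n * NStrideTensorA_ + wo * WiStride_ + c * CStrideTensorA_, so the batch
// stride no longer has to be expressible through WiStride_ alone.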
......@@ -527,20 +704,29 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
......@@ -759,20 +945,34 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
......@@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm
}
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 3 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
// Pad the length-1 dimension up to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We only need the matrices on the diagonal. XOR returns 0 for equal values, so any
// matrix that is not on the diagonal ends up in the padding.
......@@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
......@@ -1175,45 +1400,146 @@ struct TransformConvFwdToGemm
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
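The NumGroupsToMerge branch above relies on the pad-plus-XOR trick: the length-1 dimension is padded to NumGroupsToMerge, and the XOR transform maps position (g_row, g_col) to padded index g_row ^ g_col, which is the single real element only on the diagonal; everything off the diagonal lands in the padding and is never read. A tiny numeric illustration (standalone, not CK code):

#include <cstdio>

int main()
{
    constexpr int G = 4; // NumGroupsToMerge; a power of two keeps r ^ c inside [0, G)
    for(int r = 0; r < G; ++r)
    {
        for(int c = 0; c < G; ++c)
            std::printf("%d ", (r ^ c) == 0); // 1 only on the diagonal
        std::printf("\n");
    }
    return 0;
}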
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
index_t NDimSp = NDimSpatial,
typename std::enable_if<
NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
return out_gemmm_gemmn_desc;
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(NStrideTensorC_,
HoStride_,
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
// Pad the length-1 dimension up to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We only need the matrices on the diagonal. XOR returns 0 for equal values, so any
// matrix that is not on the diagonal ends up in the padding.
// To avoid a modulo after the XOR we require NumGroupsToMerge to be a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
public:
index_t N_;
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<
NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
private:
const index_t Di_, Hi_, Wi_;
const index_t Do_, Ho_, Wo_;
const index_t Z_, Y_, X_;
const index_t K_, C_;
const index_t DiStride_, HiStride_, WiStride_;
const index_t WoStride_;
const index_t XStride_;
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
const index_t NStrideTensorA_;
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
const index_t ZYX_;
index_t NDoHoWo_;
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(NStrideTensorC_,
DoStride_,
HoStride_,
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
// Pad the length-1 dimension up to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We only need the matrices on the diagonal. XOR returns 0 for equal values, so any
// matrix that is not on the diagonal ends up in the padding.
// To avoid a modulo after the XOR we require NumGroupsToMerge to be a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType XStride_;
IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType ZYX_;
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
......@@ -1230,17 +1556,17 @@ struct TransformConv
if(NDimSpatial == 2)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
}
else if(NDimSpatial == 3)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
}
else if(NDimSpatial == 1)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
}
}
};
......
......@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset);
}
template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
T* addr)
{
static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, half_t>::value)
{
vector_type<half_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
tmp.template AsType<bhalf2_t>()[i]);
});
}
#endif
}
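The helper above packs the thread's data into 2-element chunks and issues one global_atomic_fadd builtin per pair; packed fp16 is handled unconditionally, while packed bf16 is guarded to gfx942. A hypothetical device-side caller, assuming the CK headers that define vector_type and this helper are included (kernel name and indexing are illustrative):

__global__ void atomic_accumulate_sketch(ck::half_t* dst, const ck::half_t* src)
{
    using vec2_t = typename ck::vector_type<ck::half_t, 2>::type;
    // Each thread atomically adds one packed half2 into the destination buffer.
    const vec2_t v = *reinterpret_cast<const vec2_t*>(src + 2 * threadIdx.x);
    ck::amd_global_atomic_add_impl<ck::half_t, 2>(v, dst + 2 * threadIdx.x);
}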
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
......@@ -907,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
if constexpr(is_same<T, bhalf_t>::value)
{
if(dst_thread_element_valid)
{
amd_global_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, p_dst_wave + dst_thread_element_offset);
}
}
else
{
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
}
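When CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK is enabled above, invalid lanes are not branched around; instead their byte offset is pushed past the buffer resource's range so the hardware out-of-bounds check silently drops the atomic. The trick in isolation (sketch; this is my reading of the intent, not CK API):

#include <cstdint>

inline std::uint32_t oob_guarded_offset(bool element_valid, std::uint32_t byte_offset)
{
    // Adding 0x80000000 moves the address far beyond num_records, so the buffer
    // instruction's range check discards the operation instead of a branch doing it.
    return (element_valid ? 0u : 0x80000000u) + byte_offset;
}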
// buffer_atomic_max requires:
......
......@@ -358,13 +358,15 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
......
......@@ -44,7 +44,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
// convert to bitwise
using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
T_bitwise x_bitwise = bit_cast<T_bitwise>(x);
// unpack the input, depends on datatype
head = x_bitwise & NumericUtils<X>::head_mask;
......@@ -165,7 +165,7 @@ In this case, the fp16 mantissa should be shift left by 1 */
if(out_exponent > max_exp)
{
if(clip)
if constexpr(clip)
{
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
......@@ -196,18 +196,17 @@ __host__ __device__ Y run_cast_from_f8(X x)
// prepare the codes
constexpr X nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
using T_bitwise = typename NumericUtils<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
constexpr Y Inf = bit_cast<Y>(Inf_bitwise);
constexpr Y NegInf = bit_cast<Y>(NegInf_bitwise);
constexpr Y NaN = bit_cast<Y>(NaN_bitwise);
constexpr Y Neg0 = bit_cast<Y>(Neg0_bitwise);
// check if x is 0.0
if(x == 0)
......@@ -235,11 +234,12 @@ __host__ __device__ Y run_cast_from_f8(X x)
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
if constexpr((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) &&
!negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
return bit_cast<Y>(retval);
}
// subnormal input
......@@ -263,7 +263,7 @@ __host__ __device__ Y run_cast_from_f8(X x)
}
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
return bit_cast<Y>(retval);
}
} // namespace
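The conversions above now go through bit_cast instead of dereferencing a reinterpret_cast'ed pointer, which avoids the strict-aliasing problem and allows constexpr use. The standard-C++ shape of such a bit_cast, for reference (CK ships its own ck::bit_cast; this sketch is only illustrative):

#include <cstring>

template <typename To, typename From>
To bit_cast_sketch(const From& from)
{
    static_assert(sizeof(To) == sizeof(From), "bit_cast requires equal sizes");
    To to;
    std::memcpy(&to, &from, sizeof(To)); // well-defined, unlike *reinterpret_cast<To*>(&from)
    return to;
}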
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
namespace ck {
// Define the common macro for gfx94x models
......@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
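A minimal host-side usage sketch for the array_convert overloads above (element types chosen for illustration; assumes the ck headers above are included):

#include <array>

inline std::array<float, 4> ints_to_floats(const std::array<int, 4>& src)
{
    std::array<float, 4> dst{};
    ck::array_convert(dst, src); // per element: dst[i] = ck::type_convert<float>(src[i])
    return dst;
}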
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
......@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
......@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
......@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
......@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
......@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
......@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
return xor_t<LowLengths>{low_lengths};
}
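With the RightShift parameter removed, xor_t is a plain two-dimensional XOR: the first index passes through and the second becomes idx_up[1] ^ (idx_up[0] % length). The mapping in plain C++ (sketch); applying it twice returns the original column because XOR is its own inverse:

#include <cstdint>

inline std::int64_t xor_low_index(std::int64_t row, std::int64_t col, std::int64_t length)
{
    return col ^ (row % length); // mirrors calculate_lower_index above
}
// xor_low_index(row, xor_low_index(row, col, length), length) == col; lengths are
// typically powers of two so the result also stays inside [0, length).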
template <typename LowLength, typename OffsetLength>
......
......@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
......
......@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
......
......@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile(
""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
......
......@@ -53,6 +53,39 @@ class philox
out_tmp[3] = tmp_ph.w;
}
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
private:
struct ull2
{
......
......@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
......
......@@ -286,11 +286,226 @@ struct BlockDropout
});
}
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = true;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
static constexpr bool IsStoreRandval = IsStoreRandval_;
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_)
: ph(seed,
offset + (i_batch * nheads + i_head) * get_warp_size() +
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
{
}
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
number<kN1>{},
number<1>{});
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
randval_lds_block_desc_0,
ck_tile::make_tuple(
make_pass_through_transform(number<kMPerStep>{}),
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
return randval_lds_block_desc;
}
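The LDS descriptor above gives the outer kN0 dimension a stride of (kMPerStep + 1) * kN1 rather than kMPerStep * kN1, i.e. one extra kN1-wide row of padding per slab, which is the usual layout for steering accesses onto different LDS banks (the bank-conflict motivation is my reading, not stated in the code). The raw offset it produces, as a sketch:

#include <cstdint>

inline std::int64_t randval_lds_offset(std::int64_t n0, std::int64_t m, std::int64_t n1,
                                       std::int64_t kMPerStep, std::int64_t kN1)
{
    // lengths (kN0, kMPerStep, kN1), strides ((kMPerStep + 1) * kN1, kN1, 1)
    return n0 * (kMPerStep + 1) * kN1 + m * kN1 + n1;
}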
template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t MIterPerWarp = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return 2;
}
else
{
return 1;
}
}();
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Use the Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd,
// except for headdim256.
constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>)
{
if constexpr(IsWG32)
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
}
else
{
if constexpr(IsWG32)
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
}
}();
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
randval_block_inner_part_dstr_encoding);
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
typename WG::CWarpDstrEncoding{});
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
......@@ -305,30 +520,177 @@ struct BlockDropout
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute
auto randval =
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval.kThreadElementSpaceSize == 16);
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t[i_random_idx++];
constexpr auto p_idx0 =
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx] * rp_undrop
: PComputeDataType(0);
});
});
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
}
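
Because the Philox subsequence for each WG-sized tile is keyed only by its (row-block, column-block) coordinate pair, the dropout mask produced by the forward Run above matches the one the backward overload regenerates, no matter in which order the two kernels walk the tiles. Below is a standalone sketch of that keying; the hash-set comparison is purely illustrative and stands in for where ck_tile::philox would actually be invoked.

// Illustrative sketch only: the per-tile dropout key depends on the (row, col)
// block coordinates, so fwd (n-outer) and bwd (m-outer) loops see the same keys.
// The 64-bit packing mirrors reinterpret_cast<unsigned long long&>(rowcol).
#include <cstdint>
#include <cstdio>
#include <set>

static uint64_t tile_key(uint32_t block_row, uint32_t block_col)
{
    // uint2{row, col} viewed as a 64-bit value (little-endian layout)
    return (static_cast<uint64_t>(block_col) << 32) | block_row;
}

int main()
{
    constexpr int rows = 4, cols = 4;
    std::set<uint64_t> fwd_keys, bwd_keys;

    for(int n = 0; n < cols; ++n)      // fwd: column-blocks outer
        for(int m = 0; m < rows; ++m)
            fwd_keys.insert(tile_key(m, n));

    for(int m = 0; m < rows; ++m)      // bwd: row-blocks outer
        for(int n = 0; n < cols; ++n)
            bwd_keys.insert(tile_key(m, n));

    std::printf("same key set: %s\n", fwd_keys == bwd_keys ? "yes" : "no");
    return 0;
}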
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval = make_static_distributed_tensor<uint8_t>(
MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start, block_col_start;
if constexpr(IsWG32)
{
block_row_start = (start_m0_idx / WG::kM) + i_m0;
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
}
else
{
block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t* random_uint8_t_;
if constexpr(MBwdWG16SingleIterCheck)
{
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
const index_t start_idx =
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
ph.get_random_4x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else if constexpr(MBwdWG16MultiIterCheck)
{
uint8_t random_uint8_t[8];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const index_t start_idx = (get_lane_id() >> 4) & 1;
ph.get_random_8x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
random_uint8_t_ = random_uint8_t;
}
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
......@@ -337,19 +699,19 @@ struct BlockDropout
});
});
// save to Global
if(is_store_randval)
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0});
}
});
if(is_store_randval)
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
}
});
if(is_store_randval)
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
}
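
On the 16x16 WarpGemm paths each lane consumes only 4 or 8 of the random bytes, and start_idx selects which slice of the Philox output belongs to the current lane, following the lane-mapping comments above. The standalone sketch below mirrors those expressions with the lane id passed in explicitly (there is no GPU here); it is an illustration, not library code.

// Illustrative sketch only: start_idx selection for the WG16 dropout paths.
#include <cstdio>

// single-iteration path (kMPerBlock == 16): pick one of 4 byte groups
static int start_idx_single(int lane_id, int start_m0_idx)
{
    return ((lane_id >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
}

// multi-iteration path (kMPerBlock > 16): pick one of 2 byte groups
static int start_idx_multi(int lane_id)
{
    return (lane_id >> 4) & 1;
}

int main()
{
    // lanes 0-15 and 32-47 -> group 0, lanes 16-31 and 48-63 -> group 1;
    // an odd 16-row block (start_m0_idx & 16) adds 2 on the single-iter path
    for(int lane : {0, 16, 32, 48})
        std::printf("lane %2d: single(m0=0)=%d single(m0=16)=%d multi=%d\n",
                    lane,
                    start_idx_single(lane, 0),
                    start_idx_single(lane, 16),
                    start_idx_multi(lane));
    return 0;
}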
......@@ -358,7 +720,6 @@ struct BlockDropout
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
} // namespace ck_tile
......@@ -23,13 +23,9 @@
namespace ck_tile {
template <typename TilePartitioner_,
typename FmhaPipeline_,
typename KGradEpiloguePipeline_,
typename VGradEpiloguePipeline_>
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
struct FmhaBwdDQDKDVKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
......@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
// clang-format off
template <typename T> struct t2s;
......@@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr0 = typename bfs::Gemm0BlockWarps;
using gbr1 = typename bfs::Gemm1BlockWarps;
using gbr4 = typename bfs::Gemm4BlockWarps;
using gwt0 = typename bfs::Gemm0WarpTile;
using gwt1 = typename bfs::Gemm1WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
......@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
return
_SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
_TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
"r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" );
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "" );
#undef _SS_
#undef _TS_
// clang-format on
......@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
const void* lse_ptr;
const void* do_ptr;
const void* d_ptr;
void* dq_ptr;
void* dq_acc_ptr;
void* dk_ptr;
void* dv_ptr;
......@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t num_head_q;
ck_tile::index_t nhead_ratio_qk;
float raw_scale;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale;
#endif
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
......@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
};
struct FmhaBwdCommonBiasKargs
......@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
......@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
};
struct FmhaBwdBatchModeKargs
: FmhaBwdCommonKargs,
......@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs<0>>>,
std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
};
......@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs<0>>>,
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
......@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dq_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
......@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
......@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
......@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
......@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr,
do_ptr,
d_ptr,
dq_ptr,
dq_acc_ptr,
dk_ptr,
dv_ptr,
seqlen_q,
......@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q,
nhead_ratio_qk,
scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
stride_q,
stride_k,
stride_v,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
nhead_stride_q,
......@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v,
nhead_stride_do,
nhead_stride_lsed,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv};
......@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
kargs.is_store_randval = s_randval;
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
}
}
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
......@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dq_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
......@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
......@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
......@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr,
do_ptr,
d_ptr,
dq_ptr,
dq_acc_ptr,
dk_ptr,
dv_ptr,
-1, // seqlen will be updated by another pointer
......@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q,
nhead_ratio_qk,
scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
stride_q,
stride_k,
stride_v,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
nhead_stride_q,
......@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v,
nhead_stride_do,
nhead_stride_lsed,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
......@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.is_store_randval = s_randval;
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
}
}
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
......@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
return dim3(
ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
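
With the TilePartitioner template parameter dropped, the kernel now derives its own launch grid and tile index: blockIdx.x walks kN0-wide slices of seqlen_k, blockIdx.y the heads, blockIdx.z the batches. A small host-side sketch of the same arithmetic, using made-up problem sizes:

// Illustrative sketch only: host-side grid computation matching
// GridSize()/GetTileIndex(). kN0 and the problem sizes are assumptions.
#include <cstdio>

int main()
{
    constexpr int kN0 = 128;                       // FmhaPipeline::kN0 (assumed)
    const int batch = 2, nhead = 8, seqlen_k = 1000;

    const int grid_x = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil
    std::printf("grid = (%d, %d, %d)\n", grid_x, nhead, batch);

    // device side: block (i_block, i_nhead, i_batch) handles key rows
    // [i_block * kN0, min((i_block + 1) * kN0, seqlen_k))
    return 0;
}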
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
......@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_do = 0;
long_index_t batch_offset_lsed = 0;
long_index_t batch_offset_dq_acc = 0;
long_index_t batch_offset_dk = 0;
long_index_t batch_offset_dv = 0;
long_index_t batch_offset_dbias = 0;
......@@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = query_start;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
......@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias = key_start;
}
if constexpr(kHasDropout)
if constexpr(kIsStoreRandval)
{
batch_offset_randval = query_start * kargs.stride_randval;
}
......@@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
}
if constexpr(kHasDropout)
if constexpr(kIsStoreRandval)
{
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
......@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do;
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
batch_offset_dk;
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
batch_offset_dv;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
......@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
const auto q_dram = [&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto qt_dram_naive =
transform_tensor_view(q_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_q),
make_pass_through_transform(kargs.seqlen_q)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto qt_dram = [&]() {
if constexpr(FmhaPipeline::kQTLoadOnce)
{
return pad_tensor_view(
qt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
sequence<kPadHeadDimQ, kPadSeqLenQ>{});
}
else
{
return pad_tensor_view(
qt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
sequence<kPadHeadDimQ, kPadSeqLenQ>{});
}
}();
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
......@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
const auto k_dram = [&]() {
if constexpr(FmhaPipeline::kKLoadOnce)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
}();
const auto kt_dram_naive =
transform_tensor_view(k_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_q),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto kt_dram = [&]() {
if constexpr(FmhaPipeline::kKTLoadOnce)
{
return pad_tensor_view(
kt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimQ, kPadSeqLenK>{});
}
else
{
return pad_tensor_view(
kt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
sequence<kPadHeadDimQ, kPadSeqLenK>{});
}
}();
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
......@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
if constexpr(FmhaPipeline::kVLoadOnce)
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}();
const auto lse_dram = [&]() {
......@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_do, 1),
number<FmhaPipeline::kAlignmentOGrad>{},
number<1>{});
const auto do_dram = [&]() {
if constexpr(FmhaPipeline::kOGradLoadOnce)
{
return pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}
else
{
return pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}
}();
const auto dot_dram_naive =
transform_tensor_view(do_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_q)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto dot_dram = [&]() {
if constexpr(FmhaPipeline::kOGradTLoadOnce)
{
return pad_tensor_view(
dot_dram_naive,
make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
sequence<kPadHeadDimV, kPadSeqLenQ>{});
}
else
{
return pad_tensor_view(
dot_dram_naive,
make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenQ>{});
}
}();
auto dq_dram = [&]() {
const auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto do_dram = pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
auto qt_dram_window =
make_tile_window(qt_dram,
[&]() {
if constexpr(FmhaPipeline::kQTLoadOnce)
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kM0>{});
else
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kK3>{});
}(),
{0, 0});
auto k_dram_window = make_tile_window(
k_dram,
[&]() {
if constexpr(FmhaPipeline::kKLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{i_n0, 0});
auto kt_dram_window =
make_tile_window(kt_dram,
[&]() {
if constexpr(FmhaPipeline::kKTLoadOnce)
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kN0>{});
else
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kK4>{});
}(),
{0, i_n0});
auto v_dram_window = make_tile_window(
v_dram,
[&]() {
if constexpr(FmhaPipeline::kVLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kVHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
}(),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
{i_n0, 0});
auto do_dram_window = make_tile_window(
do_dram,
[&]() {
if constexpr(FmhaPipeline::kOGradLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kVHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
}(),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
{0, 0});
auto dot_dram_window =
make_tile_window(dot_dram,
[&]() {
if constexpr(FmhaPipeline::kOGradTLoadOnce)
return make_tuple(number<FmhaPipeline::kVHeaddim>{},
number<FmhaPipeline::kM0>{});
else
return make_tuple(number<FmhaPipeline::kVHeaddim>{},
number<FmhaPipeline::kK1>{});
}(),
{0, 0});
auto dq_dram_window = make_tile_window(
dq_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
if constexpr(kIsDeterministic)
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
else
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
}();
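
The two branches above differ only in how dQ partials are accumulated: with kIsDeterministic every column-block i_tile_n writes into its own slice at i_tile_n * split_stride_dq_acc, to be reduced later by the convert-dQ kernel, while the non-deterministic path lets all column-blocks atomically add into one shared buffer, whose floating-point summation order then depends on scheduling. A CPU sketch of the two write-side schemes, with illustrative sizes:

// Illustrative sketch only: where each column-block writes its dQ partial.
// Strides and sizes below are assumptions, not values from this kernel.
#include <cstdio>
#include <vector>

int main()
{
    const int  nsplits             = 4;        // ceil(seqlen_k / kN0)
    const long dq_elems            = 8;        // seqlen_q * hdim_q, flattened
    const long split_stride_dq_acc = dq_elems; // one slice per column-block

    // deterministic: column-block `split` owns its own slice, no races;
    // the convert-dQ kernel reduces the slices afterwards in a fixed order
    std::vector<float> dq_acc_det(nsplits * dq_elems, 0.f);
    for(int split = 0; split < nsplits; ++split)
        for(long i = 0; i < dq_elems; ++i)
            dq_acc_det[split * split_stride_dq_acc + i] += 0.25f;

    // non-deterministic: every column-block adds into the same slice;
    // on the GPU this is memory_operation_enum::atomic_add, and the
    // floating-point summation order depends on scheduling
    std::vector<float> dq_acc_shared(dq_elems, 0.f);
    for(int split = 0; split < nsplits; ++split)
        for(long i = 0; i < dq_elems; ++i)
            dq_acc_shared[i] += 0.25f;

    std::printf("det slice0[0]=%f shared[0]=%f\n", dq_acc_det[0], dq_acc_shared[0]);
    return 0;
}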
auto lse_dram_window =
make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
......@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read?
AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask)
{
return make_alibi_from_lr_mask<AccDataType, false>(slope,
......@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel
}();
// dropout
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
bool is_store_randval = false;
float rp_undrop = 1;
float scale_rp_undrop = 1;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop;
}
BlockDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t,
is_store_randval);
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
};
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout)
if constexpr(kIsStoreRandval)
{
RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
......@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
}();
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
qt_dram_window,
k_dram_window,
kt_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
do_dram_window,
dot_dram_window,
lse_dram_window,
d_dram_window,
dq_dram_window,
......@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
mask,
position_encoding,
kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs.scale,
#endif
rp_undrop,
scale_rp_undrop,
smem_ptr,
......@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
template <typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
......@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_d;
ck_tile::index_t batch_stride_d;
};
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
};
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
......@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d,
batch_stride_d},
nhead_stride_d},
batch_stride_do,
batch_stride_o};
batch_stride_o,
batch_stride_d};
return kargs;
}
......@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d,
ck_tile::index_t batch_stride_d)
ck_tile::index_t nhead_stride_d)
{
Kargs kargs{{o_ptr,
do_ptr,
......@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d,
batch_stride_d},
nhead_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
return kargs;
......@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
......@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
batch_offset_d = query_start;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
}
};
template <typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadHeadDimQ) n += "d";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" +
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
#undef _SS_
#undef _TS_
// clang-format on
}
    // to avoid the duplicated base class problem, introduce a template arg
template <ck_tile::index_t I>
struct FmhaBwdConvertQGradEmptyKargs
{
};
    // kargs use aggregate initialization, so no constructor is provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
struct FmhaBwdConvertQGradCommonKargs
{
const void* dq_acc_ptr;
void* dq_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
};
struct FmhaBwdConvertQGradDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
};
struct FmhaBwdConvertQGradBatchModeKargs
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
FmhaBwdConvertQGradDeterministicKargs,
FmhaBwdConvertQGradEmptyKargs<0>>
{
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
};
struct FmhaBwdConvertQGradGroupModeKargs
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
FmhaBwdConvertQGradDeterministicKargs,
FmhaBwdConvertQGradEmptyKargs<0>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
};
using Kargs = std::conditional_t<kIsGroupMode,
FmhaBwdConvertQGradGroupModeKargs,
FmhaBwdConvertQGradBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* dq_acc_ptr,
void* dq_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
dq_ptr,
seqlen_q,
seqlen_k,
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
batch_stride_dq,
batch_stride_dq_acc};
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* dq_acc_ptr,
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
dq_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
}
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
long_index_t batch_offset_dq = 0;
long_index_t batch_offset_dq_acc = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if constexpr(kIsDeterministic)
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
if(kargs.seqlen_q <= i_m0)
{
return;
}
}
else
{
batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
}
        // for simplicity, we just fold the batch stride into the pointer
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
batch_offset_dq;
// dQAcc/dQ DRAM and DRAM window
const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
if constexpr(kIsDeterministic)
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
auto dq_dram = [&]() {
auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(dq_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
auto dq_acc_dram_window = [&]() {
if constexpr(kIsDeterministic)
{
return make_tile_window(
dq_acc_dram,
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
{0, i_m0, 0});
}
else
{
return make_tile_window(
dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
}
}();
auto dq_dram_window =
make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
if constexpr(kIsDeterministic)
{
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
}
else
{
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
}
}
};
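
In deterministic mode the convert kernel views dQAcc as a (nsplits, seqlen_q, hdim_q) tensor whose leading stride is split_stride_dq_acc, reads one kM0 x kQKHeaddim tile per split, and reduces them before casting to QGradDataType. A host-side sketch of that read-side indexing and reduction; all sizes and strides are assumptions for illustration.

// Illustrative sketch only: read-side indexing of the deterministic dQAcc
// buffer, shape (nsplits, seqlen_q, hdim_q) with strides
// (split_stride_dq_acc, stride_dq_acc, 1). Sizes are assumed values.
#include <cstdio>
#include <vector>

int main()
{
    const int nsplits  = 3; // integer_divide_ceil(seqlen_k, kN0)
    const int seqlen_q = 4;
    const int hdim_q   = 2;

    const long stride_dq_acc       = hdim_q;
    const long split_stride_dq_acc = static_cast<long>(seqlen_q) * hdim_q;

    std::vector<float> dq_acc(nsplits * seqlen_q * hdim_q, 1.0f);

    // dq[q][d] = sum over splits of dq_acc[split][q][d], in fixed split order
    std::vector<float> dq(seqlen_q * hdim_q, 0.0f);
    for(int split = 0; split < nsplits; ++split)
        for(int q = 0; q < seqlen_q; ++q)
            for(int d = 0; d < hdim_q; ++d)
                dq[q * hdim_q + d] +=
                    dq_acc[split * split_stride_dq_acc + q * stride_dq_acc + d];

    std::printf("dq[0] = %f (expected %d)\n", dq[0], nsplits);
    return 0;
}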
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
......@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
......@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
......@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
......@@ -524,7 +522,7 @@ struct FmhaFwdKernel
}
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
batch_offset_lse = query_start;
}
if constexpr(kHasDropout)
{
......
......@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) +
(kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
......@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
......@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
......@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE)
{
......@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
......@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
......@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
        // for simplicity, we just fold the batch stride into the pointer
......
......@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
......@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
......@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
......@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
......@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
......@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
......@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......