Commit 9bb4ab41 authored by Junhao's avatar Junhao
Browse files

Merge branch 'junhzhan/fa-ifu-mqa' of...

Merge branch 'junhzhan/fa-ifu-mqa' of https://github.com/ROCmSoftwarePlatform/composable_kernel into junhzhan/fa-ifu-mqa
parents 980b8835 5ff2d646
......@@ -119,6 +119,15 @@ struct GemmGemmPadder
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
}
// C[M, Gemm1N] = C[M, N]
template <typename C0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadC0Descriptor_M_N(const C0Desc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
......
......@@ -31,7 +31,7 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32;
}
#if FLASH_ATTENTION_INTERNAL_USE_RTN
#ifdef USE_RTN_BF16_CONVERT
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment