"...composable_kernel_rocm.git" did not exist on "5ae893c0d331a3b7f88a3c8e00cb6dfac9bf45a8"
Unverified Commit 43879b89 authored by rocking's avatar rocking Committed by GitHub
Browse files

Small refactor (#1246)



* Remove kIsFp8

* Extract alias

* Fix K, V and corresponding acc type

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent ad1597c4
......@@ -27,13 +27,12 @@ struct FmhaFwdKernel
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......
......@@ -49,13 +49,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kIsFp8 =
(std::is_same_v<QDataType, fp8_t> || std::is_same_v<QDataType, bf8_t>)&&(
std::is_same_v<KDataType, fp8_t> ||
std::is_same_v<KDataType, bf8_t>)&&(std::is_same_v<VDataType, fp8_t> ||
std::is_same_v<VDataType, bf8_t>)&&std::
is_same_v<SaccDataType, float> &&
std::is_same_v<OaccDataType, float>;
};
} // namespace ck_tile
......@@ -31,7 +31,6 @@ struct BlockFmhaPipelineQRKSVS
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......
......@@ -32,7 +32,6 @@ struct BlockFmhaPipelineQRKSVSAsync
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......
......@@ -31,7 +31,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......
......@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQSKSVS
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = false;
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......
......@@ -97,16 +97,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(Problem::kIsFp8)
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
......@@ -221,16 +220,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(Problem::kIsFp8)
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
......@@ -920,12 +918,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
Problem::BlockFmhaShape::kK1>>;
auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8)
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::VDataType, fp8_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
......
......@@ -102,4 +102,11 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
2,
swizzle_factor>>;
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment