Unverified Commit 43879b89 authored by rocking, committed by GitHub
Browse files

Small refactor (#1246)



* Remove kIsFp8

* Extract alias

* Fix K, V and corresponding acc type

---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
parent ad1597c4
...@@ -33,7 +33,6 @@ struct FmhaFwdKernel ...@@ -33,7 +33,6 @@ struct FmhaFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>; using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>; using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>; using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>; using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......
...@@ -49,13 +49,6 @@ struct BlockFmhaPipelineProblem ...@@ -49,13 +49,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
// Compile-time flag: true only when all of the following hold:
//   - QDataType is fp8_t or bf8_t,
//   - KDataType is fp8_t or bf8_t,
//   - VDataType is fp8_t or bf8_t,
//   - SaccDataType is float, and
//   - OaccDataType is float.
// Q, K and V are checked independently, so mixed fp8_t/bf8_t combinations
// also yield true.
static constexpr bool kIsFp8 =
(std::is_same_v<QDataType, fp8_t> || std::is_same_v<QDataType, bf8_t>)&&(
std::is_same_v<KDataType, fp8_t> ||
std::is_same_v<KDataType, bf8_t>)&&(std::is_same_v<VDataType, fp8_t> ||
std::is_same_v<VDataType, bf8_t>)&&std::
is_same_v<SaccDataType, float> &&
std::is_same_v<OaccDataType, float>;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -31,7 +31,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -31,7 +31,6 @@ struct BlockFmhaPipelineQRKSVS
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
......
...@@ -32,7 +32,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -32,7 +32,6 @@ struct BlockFmhaPipelineQRKSVSAsync
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
......
...@@ -31,7 +31,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -31,7 +31,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
......
...@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQSKSVS
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = false; static constexpr bool kQLoadOnce = false;
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kIsFp8 = Problem::kIsFp8;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
......
...@@ -97,16 +97,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -97,16 +97,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{ {
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
} }
else if constexpr(Problem::kIsFp8) else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{ {
constexpr index_t swizzle_factor = 4; // TODO: hard coded here // TODO: hard coded here. Otherwise, it may produce an incorrect result
return WarpGemmImpl< constexpr index_t swizzle_factor = 4;
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType, swizzle_factor>{};
typename Problem::KDataType>, } // TODO - bf8_t
2,
swizzle_factor>>{};
}
}(); }();
using BlockGemmPolicy = using BlockGemmPolicy =
...@@ -221,16 +220,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -221,16 +220,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{ {
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
} }
else if constexpr(Problem::kIsFp8) else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{ {
constexpr index_t swizzle_factor = 4; // TODO: hard coded here // TODO: hard coded here. Otherwise, it may produce an incorrect result
return WarpGemmImpl< constexpr index_t swizzle_factor = 4;
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType, swizzle_factor>{};
typename Problem::KDataType>, } // TODO - bf8_t
2,
swizzle_factor>>{};
}
}(); }();
using BlockGemmPolicy = using BlockGemmPolicy =
...@@ -920,12 +918,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -920,12 +918,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
Problem::BlockFmhaShape::kK1>>; Problem::BlockFmhaShape::kK1>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8) if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::VDataType, fp8_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
// return // return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB< // WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename // WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
......
...@@ -102,4 +102,11 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< ...@@ -102,4 +102,11 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>; WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
// Alias template naming a WarpGemmImpl instantiated with the
// WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB attribute
// over WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>.
//
// Template parameter:
//   swizzle_factor - forwarded unchanged as the attribute's third template
//                    argument; defaults to 2 when the alias is used as `<>`.
//
// The attribute's second template argument is fixed to the literal 2 here.
// NOTE(review): going by the attribute's name this is presumably the
// K-iteration count - confirm against the attribute's definition.
// Both operand types are fixed to fp8_t; bf8_t is not covered by this alias.
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
2,
swizzle_factor>>;
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment