Commit efab74a3 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'gfx950' into lwpck-2619

parents 86950b3a bcef33c1
...@@ -11,7 +11,6 @@ namespace ck_tile { ...@@ -11,7 +11,6 @@ namespace ck_tile {
// UniversalGemm Policy // UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy struct UniversalGemmPipelineAgBgCrPolicy
{ {
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
...@@ -444,6 +443,8 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -444,6 +443,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
} }
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
......
...@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = ...@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// bf16 // bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
...@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = ...@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp8 // fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
......
...@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma ...@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding< static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
sequence<>, "Multi-block on both M & N directions is not supported");
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
tuple<sequence<2, 1>>, {
tuple<sequence<0, 0>>, return tile_distribution_encoding<
sequence<2>, sequence<>,
sequence<1>>; tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
using CWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
sequence<Impl::kCNLane>>, {
tuple<sequence<1, 2>>, return tile_distribution_encoding<
tuple<sequence<1, 0>>, sequence<>,
sequence<1, 1>, tuple<sequence<Impl::kBNLane>,
sequence<0, 2>>; sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kBNBlock * Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec // c_vec += a_vec * b_vec
template <bool post_nop_ = false> template <bool post_nop_ = false>
...@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution ...@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB ...@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution ...@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding< static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
sequence<>, "Multi-block on both M & N directions is not supported");
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
tuple<sequence<2, 1>>, {
tuple<sequence<0, 0>>, return tile_distribution_encoding<
sequence<2>, sequence<>,
sequence<1>>; tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
using CWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kCNLane>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>, {
tuple<sequence<2, 1>>, return tile_distribution_encoding<
tuple<sequence<1, 0>>, sequence<>,
sequence<2, 2>, tuple<sequence<Impl::kAMLane>,
sequence<0, 2>>; sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
template <bool post_nop_ = false> template <bool post_nop_ = false>
// c_vec += a_vec * b_vec // c_vec += a_vec * b_vec
...@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB ...@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
...@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA ...@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane), tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
......
...@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 8; static constexpr index_t kK = 8;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
...@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static constexpr index_t kN = 16; static constexpr index_t kN = 16;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16; static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16; static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4; static constexpr index_t kABKLane = 4;
...@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
} }
}; };
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 4;
static constexpr index_t kN = 64;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 16;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 64;
static constexpr index_t kN = 4;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 16;
static constexpr index_t kBNBlock = 1;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// Bf16 // Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 8; static constexpr index_t kK = 8;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
...@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static constexpr index_t kN = 16; static constexpr index_t kN = 16;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16; static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16; static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4; static constexpr index_t kABKLane = 4;
...@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
} }
}; };
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 4;
static constexpr index_t kN = 64;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 16;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 64;
static constexpr index_t kN = 4;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 16;
static constexpr index_t kBNBlock = 1;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// FP8 // FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
...@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 ...@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
......
...@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float ...@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
...@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float ...@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,8 @@ struct Layernorm2dFwdHostArgs ...@@ -14,7 +14,8 @@ struct Layernorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -42,15 +43,16 @@ struct Layernorm2dFwd ...@@ -42,15 +43,16 @@ struct Layernorm2dFwd
using Epilogue = remove_cvref_t<Epilogue_>; using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>; using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X // for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -67,6 +69,7 @@ struct Layernorm2dFwd ...@@ -67,6 +69,7 @@ struct Layernorm2dFwd
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -81,7 +84,8 @@ struct Layernorm2dFwd ...@@ -81,7 +84,8 @@ struct Layernorm2dFwd
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -107,7 +111,8 @@ struct Layernorm2dFwd ...@@ -107,7 +111,8 @@ struct Layernorm2dFwd
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_sm_scale,
hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
hargs.p_y, hargs.p_y,
...@@ -152,6 +157,7 @@ struct Layernorm2dFwd ...@@ -152,6 +157,7 @@ struct Layernorm2dFwd
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name; if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name; if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
...@@ -165,7 +171,7 @@ struct Layernorm2dFwd ...@@ -165,7 +171,7 @@ struct Layernorm2dFwd
base_str += _SS_("_") + _SS_(t2s<YDataType>::name); base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name); base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name); base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
...@@ -228,6 +234,27 @@ struct Layernorm2dFwd ...@@ -228,6 +234,27 @@ struct Layernorm2dFwd
} }
}(); }();
const auto x_bias_window = [&]() {
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XBiasDataType*>(kargs.p_x_bias),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -329,18 +356,18 @@ struct Layernorm2dFwd ...@@ -329,18 +356,18 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto x_scale_window = [&]() { auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
const auto win_ = [&]() { const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>( const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale), static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n), make_tuple(kargs.n),
number<Vector_N>{}); number<Vector_N>{});
return pad_tensor_view(tmp_0_, return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}), make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad sequence<false>{}); // sm_scale no need pad
}(); }();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0}); return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
} }
...@@ -371,13 +398,14 @@ struct Layernorm2dFwd ...@@ -371,13 +398,14 @@ struct Layernorm2dFwd
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window, x_residual_window,
x_bias_window,
gamma_window, gamma_window,
beta_window, beta_window,
y_window, y_window,
y_residual_window, y_residual_window,
mean_window, mean_window,
inv_std_window, inv_std_window,
x_scale_window, sm_scale_window,
y_scale_window, y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelford<P_>{}; return BlockNormReduce<P_>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordSync<P_>{}; return BlockNormReduceSync<P_>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordCrossWarpSync<P_>{}; return BlockNormReduceCrossWarpSync<P_>{};
} }
template <typename Problem> template <typename Problem>
...@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
using block_welford = BlockWelford<P_>; using block_welford = BlockNormReduce<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile = using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>()); decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
return GetBlockWelfordCrossWarpSync<Problem>() return GetBlockNormReduceCrossWarpSync<Problem>()
.template GetSmemSize<mean_var_block_tile>(); .template GetSmemSize<mean_var_block_tile>();
} }
else else
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -37,6 +38,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -37,6 +38,8 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -54,24 +57,26 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -54,24 +57,26 @@ struct Layernorm2dFwdPipelineOnePass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window_, YWindow& y_window_,
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_, const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
...@@ -80,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -80,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window( const auto beta_window = make_tile_window(
...@@ -89,23 +96,38 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -89,23 +96,38 @@ struct Layernorm2dFwdPipelineOnePass
auto y_residual_window = make_tile_window( auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size); block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
auto block_welford = Policy::template GetBlockWelford<Problem>(); auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>(); auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_welford_cross_warp_sync = auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>(); Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
clear_tile(mean);
clear_tile(var);
// load gamma/beta (TODO: support no gamma/beta?) // load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window); const auto beta = load_tile(beta_window);
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -117,12 +139,21 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -117,12 +139,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc)); store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
} }
// compute welford each-thread->cross-lane->cross-warp // compute reduce each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(acc, cur_count, max_count); block_norm_reduce(acc, mean, var, cur_count, max_count);
block_welford_sync(mean, var, cur_count); block_norm_reduce_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem); block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{}); if(kWelford)
{
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
}
else
{
sweep_tile(mean, [&](auto idx) {
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
});
}
// compute inv-std // compute inv-std
auto inv_std = tile_elementwise_in( auto inv_std = tile_elementwise_in(
[&](const auto& v_) { [&](const auto& v_) {
...@@ -153,14 +184,13 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -153,14 +184,13 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]); const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
ln(idx) = ln_;
}); });
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem); Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
} }
else else
Epilogue{}(y_window_, ln); Epilogue{}(y_window_, ln);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,28 +8,30 @@ ...@@ -8,28 +8,30 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename XBiasDataType_,
typename GammaDataType_, typename GammaDataType_,
typename BetaDataType_, typename BetaDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
typename Traits_> typename Traits_>
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -53,32 +56,37 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -53,32 +56,37 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window, YWindow& y_window,
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/, const SmoothScaleWindow& /*sm_scale_window*/,
YScaleWindow& /*y_scale_window*/, YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem, void* smem,
Epilogue) const Epilogue) const
{ {
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window( auto beta_window = make_tile_window(
...@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass
int max_count = int max_count =
(num_n_tile_iteration - 1) * count_per_iter + (num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n); block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>(); auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>(); auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_welford_cross_warp_sync = auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>(); Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window))); using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(y_residual_window, {0, Block_N}); move_tile_window(y_residual_window, {0, Block_N});
} }
} }
block_welford(acc, mean, var, cur_count, max_count); block_norm_reduce(acc, mean, var, cur_count, max_count);
} }
block_welford_sync(mean, var, cur_count); block_norm_reduce_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem); block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{}); block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std // compute inv-std
...@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
...@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation // layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x); const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
...@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N}); move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
......
...@@ -7,6 +7,19 @@ ...@@ -7,6 +7,19 @@
namespace ck_tile { namespace ck_tile {
enum class Layernorm2dXBiasEnum
{
NO_BIAS = 0,
// add bias before fused add
ADD_BIAS = 1,
};
// clang-format off
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
// clang-format on
enum class Layernorm2dFusedAddEnum enum class Layernorm2dFusedAddEnum
{ {
NO_ADD = 0, NO_ADD = 0,
...@@ -40,7 +53,9 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT ...@@ -40,7 +53,9 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template <bool kPadN_, template <bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveMeanInvStd_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
Layernorm2dXBiasEnum kXbias_,
Layernorm2dFusedAddEnum kFusedAdd_, Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_> Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits struct Layernorm2dFwdTraits
...@@ -48,7 +63,9 @@ struct Layernorm2dFwdTraits ...@@ -48,7 +63,9 @@ struct Layernorm2dFwdTraits
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/ops/welford/block/block_welford.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -4,22 +4,23 @@ ...@@ -4,22 +4,23 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct BlockWelford struct BlockNormReduce
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType; using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType; using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kFastFDiv = Problem::kFastFDiv; static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
CK_TILE_DEVICE constexpr BlockWelford() {} CK_TILE_DEVICE constexpr BlockNormReduce() {}
// [CAUSION] - max_count_ is to deal with the padding problem // [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
// calculation of max_count_ // calculation of max_count_
// -> use block_welford_calculate_max_count to compute // -> use block_welford_calculate_max_count to compute
template <typename XDistributedTensor_, template <typename XDistributedTensor_,
...@@ -40,18 +41,24 @@ struct BlockWelford ...@@ -40,18 +41,24 @@ struct BlockWelford
if(cur_count_ < max_count_) if(cur_count_ < max_count_)
{ {
++cur_count_; ++cur_count_;
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]); auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
if(kWelford)
welford_update(mean_tensor(out_dstr_idx), {
var_tensor(out_dstr_idx), welford_update(mean_tensor(out_dstr_idx),
x, var_tensor(out_dstr_idx),
cur_count_, x,
constant<kFastFDiv>{}); cur_count_,
constant<kFastFDiv>{});
}
else
{
mean_tensor(out_dstr_idx) += x;
var_tensor(out_dstr_idx) += x * x;
}
}); });
} }
}); });
...@@ -91,10 +98,11 @@ struct BlockWelford ...@@ -91,10 +98,11 @@ struct BlockWelford
}; };
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct BlockWelfordSync struct BlockNormReduceSync
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
static constexpr bool kFastFDiv = Problem::kFastFDiv; static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
template <typename MeanDistributedTensor_, typename VarDistributedTensor_> template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void CK_TILE_DEVICE void
...@@ -152,36 +160,48 @@ struct BlockWelfordSync ...@@ -152,36 +160,48 @@ struct BlockWelfordSync
(number<lid_over_rid_derivative << istage.value>{}.value); (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane // pull data from remote lane
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane); const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
const auto v_remote_count = warp_shuffle(v_local_count, src_lane); if(kWelford)
{
// welford merge const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
welford_merge(v_local_mean,
v_local_var, // norm_reduce merge
v_local_count, welford_merge(v_local_mean,
v_remote_mean, v_local_var,
v_remote_var, v_local_count,
v_remote_count, v_remote_mean,
constant<kFastFDiv>{}); v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
}); });
} }
}); });
mean_tensor.get_thread_buffer()(i) = v_local_mean; mean_tensor.get_thread_buffer()(i) = v_local_mean;
var_tensor.get_thread_buffer()(i) = v_local_var; var_tensor.get_thread_buffer()(i) = v_local_var;
if(kWelford)
count = v_local_count; {
count = v_local_count;
}
}); });
} }
}; };
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct BlockWelfordCrossWarpSync struct BlockNormReduceCrossWarpSync
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape; using BlockShape = typename Problem::BlockShape;
static constexpr bool kFastFDiv = Problem::kFastFDiv; static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
template <typename MeanDistributedTensor_> template <typename MeanDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps() CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
...@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync ...@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
// Note: we always pack everything into fp32x4 // Note: we always pack everything into fp32x4
fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem); smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
const index_t lane_id = get_lane_id(); const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id(); const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>(); constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
...@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync ...@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if(lane_id == 0) if(lane_id == 0)
{ {
static_for<0, thread_buf_size, 1>{}([&](auto i) { static_for<0, thread_buf_size, 1>{}([&](auto i) {
fp32x4_t local_scratch_; smem_dtype local_scratch_;
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]); local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]); local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
local_scratch_[2] = bit_cast<float>(count); if(kWelford)
{
local_scratch_[2] = bit_cast<float>(count);
}
smem_ptr[smem_offset + i * num_warps] = local_scratch_; smem_ptr[smem_offset + i * num_warps] = local_scratch_;
}); });
} }
...@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync ...@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let everythread to do compute :) // load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps; index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps; index_t local_smem_os = local_warp_id * num_reduce_warps;
fp32x4_t all_scratch[thread_buf_size * num_reduce_warps]; smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] = all_scratch[i_0 * num_reduce_warps + i_1] =
...@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync ...@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this // TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps]; auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]); auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]); auto v_local_var = bit_cast<DataType>(v_local[1]);
auto v_local_count = bit_cast<int>(v_local[2]); int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
// further reduce mean/var // further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{}; constexpr auto i_1 = number<i_1_n1 + 1>{};
const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]); const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
const auto v_remote_var = bit_cast<DataType>(v_remote[1]); const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
const auto v_remote_count = bit_cast<int>(v_remote[2]); if(kWelford)
{
welford_merge(v_local_mean, const auto v_remote_count = bit_cast<int>(v_remote[2]);
v_local_var,
v_local_count, welford_merge(v_local_mean,
v_remote_mean, v_local_var,
v_remote_var, v_local_count,
v_remote_count, v_remote_mean,
constant<kFastFDiv>{}); v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
}); });
mean_tensor.get_thread_buffer()(i_0) = v_local_mean; mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
var_tensor.get_thread_buffer()(i_0) = v_local_var; var_tensor.get_thread_buffer()(i_0) = v_local_var;
if(kWelford)
count = v_local_count; count = v_local_count;
}); });
} }
}; };
......
...@@ -7,13 +7,18 @@ ...@@ -7,13 +7,18 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_> template <typename XDataType_,
struct BlockWelfordProblem typename ComputeDataType_,
typename BlockShape_,
bool kFastFDiv_,
bool kWelford_>
struct BlockNormReduceProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,5 +8,6 @@ ...@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment