Unverified commit 8b49f207 authored by Max Podkorytov, committed by GitHub

Merge branch 'develop' into fa-h512

parents 0d59f474 a6b761c3
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
     static constexpr index_t kN = 32;
     static constexpr index_t kK = 8;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 32;
     static constexpr index_t kBNLane = 32;
     static constexpr index_t kABKLane = 2;
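The kAMBlock/kBNBlock constants added across these hunks record how many independent MFMA blocks tile the warp tile's M and N directions. Every pre-existing layout is single-block (both 1); the new 4x64 and 64x4 attribute structs below are the first to use 16 blocks along their long dimension.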
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     static constexpr index_t kN = 16;
     static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 16;
     static constexpr index_t kBNLane = 16;
     static constexpr index_t kABKLane = 4;
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     }
 };
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = fp16_t;
+    using BDataType = fp16_t;
+    using CDataType = float;
+    using AVecType = ext_vector_t<fp16_t, 4>;
+    using BVecType = ext_vector_t<fp16_t, 4>;
+    using CVecType = ext_vector_t<float, 4>;
+    static constexpr index_t kM = 4;
+    static constexpr index_t kN = 64;
+    static constexpr index_t kK = 4;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 16;
+    // we only write down single block (4 threads) thread mapping here
+    static constexpr index_t kAMLane = 4;
+    static constexpr index_t kBNLane = 4;
+    static constexpr index_t kABKLane = 1;
+    static constexpr index_t kABKPerLane = 4;
+    static constexpr index_t kCMLane = 1;
+    static constexpr index_t kCNLane = 4;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
+        else
+        {
+#if defined(__gfx9__)
+            c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ignore = c_vec;
+            ignore = a_vec;
+            ignore = b_vec;
+#endif
+        }
+    }
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx9__)
+        return bit_cast<CVecType>(
+            __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
+#else
+        ignore = a_vec;
+        ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
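To make the new mapping concrete: with kAMBlock = 1 and kBNBlock = 16 a wave reuses one 4x4 A tile across 16 MFMA blocks, each producing a 4-column slice of the 4x64 C tile. A host-side reference sketch written for this edit (not part of the diff):

#include <array>
#include <cstdio>

// Reference: C(4x64) += A(4x4) * B(4x64), evaluated as 16 independent 4x4x4
// blocks the way the wave-level MFMA does (A broadcast, B split along N).
int main()
{
    constexpr int M = 4, N = 64, K = 4, BlockN = 4, NBlocks = N / BlockN;
    std::array<float, M * K> a{};
    std::array<float, K * N> b{};
    std::array<float, M * N> c{}; // accumulator, plays the role of c_vec
    for(int i = 0; i < M * K; ++i) a[i] = 0.01f * i;
    for(int i = 0; i < K * N; ++i) b[i] = 0.02f * i;
    for(int blk = 0; blk < NBlocks; ++blk)      // one MFMA block per 4 columns
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < BlockN; ++n)
                for(int k = 0; k < K; ++k)      // c_vec += a_vec * b_vec
                    c[m * N + blk * BlockN + n] += a[m * K + k] * b[k * N + blk * BlockN + n];
    std::printf("c[0][0] = %f\n", c[0]);
    return 0;
}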
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = fp16_t;
+    using BDataType = fp16_t;
+    using CDataType = float;
+    using AVecType = ext_vector_t<fp16_t, 4>;
+    using BVecType = ext_vector_t<fp16_t, 4>;
+    using CVecType = ext_vector_t<float, 4>;
+    static constexpr index_t kM = 64;
+    static constexpr index_t kN = 4;
+    static constexpr index_t kK = 4;
+    static constexpr index_t kAMBlock = 16;
+    static constexpr index_t kBNBlock = 1;
+    // we only write down single block (4 threads) thread mapping here
+    static constexpr index_t kAMLane = 4;
+    static constexpr index_t kBNLane = 4;
+    static constexpr index_t kABKLane = 1;
+    static constexpr index_t kABKPerLane = 4;
+    static constexpr index_t kCMLane = 1;
+    static constexpr index_t kCNLane = 4;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
+        else
+        {
+#if defined(__gfx9__)
+            c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ignore = c_vec;
+            ignore = a_vec;
+            ignore = b_vec;
+#endif
+        }
+    }
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx9__)
+        return bit_cast<CVecType>(
+            __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
+#else
+        ignore = a_vec;
+        ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
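The M64N4 variant mirrors the M4N64 one: here the 4x4 B tile is shared (kBNBlock = 1) while 16 MFMA blocks tile the M direction (kAMBlock = 16), so each block contributes a 4-row slice of the 64x4 C tile; the per-block lane mapping is identical.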
 // Bf16
 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
 struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
     static constexpr index_t kN = 32;
     static constexpr index_t kK = 8;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 32;
     static constexpr index_t kBNLane = 32;
     static constexpr index_t kABKLane = 2;
@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     static constexpr index_t kN = 16;
     static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 16;
     static constexpr index_t kBNLane = 16;
     static constexpr index_t kABKLane = 4;
@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     }
 };
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = bf16_t;
+    using BDataType = bf16_t;
+    using CDataType = float;
+    using AVecType = ext_vector_t<bf16_t, 4>;
+    using BVecType = ext_vector_t<bf16_t, 4>;
+    using CVecType = ext_vector_t<float, 4>;
+    static constexpr index_t kM = 4;
+    static constexpr index_t kN = 64;
+    static constexpr index_t kK = 4;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 16;
+    // we only write down single block (4 threads) thread mapping here
+    static constexpr index_t kAMLane = 4;
+    static constexpr index_t kBNLane = 4;
+    static constexpr index_t kABKLane = 1;
+    static constexpr index_t kABKPerLane = 4;
+    static constexpr index_t kCMLane = 1;
+    static constexpr index_t kCNLane = 4;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
+        else
+        {
+#if defined(__gfx9__)
+            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ignore = c_vec;
+            ignore = a_vec;
+            ignore = b_vec;
+#endif
+        }
+    }
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx9__)
+        return bit_cast<CVecType>(
+            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
+#else
+        ignore = a_vec;
+        ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = bf16_t;
+    using BDataType = bf16_t;
+    using CDataType = float;
+    using AVecType = ext_vector_t<bf16_t, 4>;
+    using BVecType = ext_vector_t<bf16_t, 4>;
+    using CVecType = ext_vector_t<float, 4>;
+    static constexpr index_t kM = 64;
+    static constexpr index_t kN = 4;
+    static constexpr index_t kK = 4;
+    static constexpr index_t kAMBlock = 16;
+    static constexpr index_t kBNBlock = 1;
+    // we only write down single block (4 threads) thread mapping here
+    static constexpr index_t kAMLane = 4;
+    static constexpr index_t kBNLane = 4;
+    static constexpr index_t kABKLane = 1;
+    static constexpr index_t kABKPerLane = 4;
+    static constexpr index_t kCMLane = 1;
+    static constexpr index_t kCNLane = 4;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
+        else
+        {
+#if defined(__gfx9__)
+            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ignore = c_vec;
+            ignore = a_vec;
+            ignore = b_vec;
+#endif
+        }
+    }
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx9__)
+        return bit_cast<CVecType>(
+            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
+#else
+        ignore = a_vec;
+        ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
 // FP8
 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
 struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
     static constexpr index_t kN = 32;
     static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 32;
     static constexpr index_t kBNLane = 32;
     static constexpr index_t kABKLane = 2;
@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
     static constexpr index_t kN = 32;
     static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
     static constexpr index_t kAMLane = 32;
     static constexpr index_t kBNLane = 32;
     static constexpr index_t kABKLane = 2;
...
@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
 template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
...
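For orientation, a consumer pulls the new types through the dispatcher's nested Type. A minimal sketch, assuming the resulting warp-GEMM type exposes the warp-tile extents the way the attribute impls above do (the alias name is ours):

// Hypothetical usage: select the skinny 4x64 fp16 warp GEMM added above.
// Parameter order follows the specializations: AType, BType, AccType, M, N, K, TransposedC.
using WarpGemm4x64 = typename ck_tile::WarpGemmMfmaDispatcher<
    ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false>::Type;
static_assert(WarpGemm4x64::kM == 4 && WarpGemm4x64::kN == 64, "warp tile shape");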
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -4,8 +4,8 @@
 #pragma once
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
-#include "ck_tile/ops/welford/block/block_welford.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
 namespace ck_tile {
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
     {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
-        return BlockWelford<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduce<P_>{};
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
-        return BlockWelfordSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduceSync<P_>{};
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
-        return BlockWelfordCrossWarpSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduceCrossWarpSync<P_>{};
     }
     template <typename Problem>
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     {
        if constexpr(Problem::kNeedCrossWarpSync)
        {
-            using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                           typename Problem::ComputeDataType,
-                                           typename Problem::BlockShape,
-                                           Problem::Traits::kFastFDiv>;
-            using block_welford = BlockWelford<P_>;
+            using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                              typename Problem::ComputeDataType,
+                                              typename Problem::BlockShape,
+                                              Problem::Traits::kFastFDiv,
+                                              Problem::Traits::kWelford>;
+            using block_welford = BlockNormReduce<P_>;
             using x_block_tile =
                 decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                     MakeXBlockTileDistribution<Problem>()));
             using mean_var_block_tile =
                 decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
-            return GetBlockWelfordCrossWarpSync<Problem>()
+            return GetBlockNormReduceCrossWarpSync<Problem>()
                 .template GetSmemSize<mean_var_block_tile>();
        }
        else
...
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
         int cur_count = 0;
         int max_count =
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
-        auto block_welford = Policy::template GetBlockWelford<Problem>();
-        auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
-        auto block_welford_cross_warp_sync =
-            Policy::template GetBlockWelfordCrossWarpSync<Problem>();
+        auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
+        using XTensorType = decltype(cast_tile<ComputeDataType>(x));
+        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(mean);
+        clear_tile(var);
         // load gamma/beta (TODO: support no gamma/beta?)
         const auto gamma = load_tile(gamma_window);
         const auto beta = load_tile(beta_window);
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
             store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
         }
-        // compute welford each-thread->cross-lane->cross-warp
-        auto [mean, var] = block_welford(acc, cur_count, max_count);
-        block_welford_sync(mean, var, cur_count);
-        block_welford_cross_warp_sync(mean, var, cur_count, smem);
-        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
+        // compute reduce each-thread->cross-lane->cross-warp
+        block_norm_reduce(acc, mean, var, cur_count, max_count);
+        block_norm_reduce_sync(mean, var, cur_count);
+        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
+        if(kWelford)
+        {
+            block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
+        }
+        else
+        {
+            sweep_tile(mean, [&](auto idx) {
+                mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
+                var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
+            });
+        }
         // compute inv-std
         auto inv_std = tile_elementwise_in(
             [&](const auto& v_) {
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
                 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
                 auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
-
                 ln(idx) = ln_;
             });
         if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
...
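The else-branch above finalizes the two-moment path as var = E[x^2] - E[x]^2. A standalone sketch written for this edit, contrasting the two modes the kWelford switch selects:

#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> x{1.f, 2.f, 3.f, 4.f, 5.f};
    // kWelford == false: accumulate raw moments, finalize at the end
    float s1 = 0.f, s2 = 0.f;
    for(float v : x) { s1 += v; s2 += v * v; }
    float mean = s1 / x.size();
    float var  = s2 / x.size() - mean * mean; // E[x^2] - E[x]^2
    // kWelford == true: numerically stabler running update
    float w_mean = 0.f, w_m2 = 0.f;
    int n = 0;
    for(float v : x)
    {
        ++n;
        float d = v - w_mean;
        w_mean += d / n;
        w_m2 += d * (v - w_mean);
    }
    float w_var = w_m2 / n; // cf. block_tile_welford_post_scale_var
    std::printf("two-moment: %f %f | welford: %f %f\n", mean, var, w_mean, w_var);
    return 0;
}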
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
                                    void* smem,
                                    Epilogue) const
     {
+        static_assert(kWelford == true, "2 pass only supports welford merge");
         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
         auto gamma_window = make_tile_window(
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
         int max_count =
             (num_n_tile_iteration - 1) * count_per_iter +
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
-        auto block_welford = Policy::template GetBlockWelford<Problem>();
-        auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
-        auto block_welford_cross_warp_sync =
-            Policy::template GetBlockWelfordCrossWarpSync<Problem>();
+        auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
         using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-        auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
-        auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
+        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
                     move_tile_window(y_residual_window, {0, Block_N});
                 }
             }
-            block_welford(acc, mean, var, cur_count, max_count);
+            block_norm_reduce(acc, mean, var, cur_count, max_count);
         }
-        block_welford_sync(mean, var, cur_count);
-        block_welford_cross_warp_sync(mean, var, cur_count, smem);
+        block_norm_reduce_sync(mean, var, cur_count);
+        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
         block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
         // compute inv-std
...
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
 template <bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           Layernorm2dFusedAddEnum kFusedAdd_,
           Layernorm2dFusedQuantEnum kFusedQuant_>
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
     static constexpr bool kTwoPass = kTwoPass_;
     static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
     static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
...
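With this trait in place the flag flows top-down: Layernorm2dFwdTraits::kWelford is read by the pipelines as kWelford, forwarded into BlockNormReduceProblem's fifth template argument by the policy's Get* helpers, and finally selects between the Welford branch (welford_update/welford_merge plus post-scale) and the plain sum/sum-of-squares branch at every reduction stage in the block_norm_reduce hunks that follow.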
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
-#include "ck_tile/ops/welford/block/block_welford.hpp"
-#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
-#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -4,22 +4,23 @@
 #pragma once
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
 namespace ck_tile {
 template <typename Problem_, typename Policy_ = void>
-struct BlockWelford
+struct BlockNormReduce
 {
     using Problem = remove_cvref_t<Problem_>;
     using XDataType = typename Problem::XDataType;
     using ComputeDataType = typename Problem::ComputeDataType;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
-    CK_TILE_DEVICE constexpr BlockWelford() {}
+    CK_TILE_DEVICE constexpr BlockNormReduce() {}
     // [CAUSION] - max_count_ is to deal with the padding problem
-    // max_count_ is depend on caller, eg: naive and splitN welford will have different
+    // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
     // calculation of max_count_
     // -> use block_welford_calculate_max_count to compute
 template <typename XDistributedTensor_,
@@ -40,18 +41,24 @@ struct BlockWelford
         if(cur_count_ < max_count_)
         {
             ++cur_count_;
             sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
                 constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
                 constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
                 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
-                welford_update(mean_tensor(out_dstr_idx),
-                               var_tensor(out_dstr_idx),
-                               x,
-                               cur_count_,
-                               constant<kFastFDiv>{});
+                if(kWelford)
+                {
+                    welford_update(mean_tensor(out_dstr_idx),
+                                   var_tensor(out_dstr_idx),
+                                   x,
+                                   cur_count_,
+                                   constant<kFastFDiv>{});
+                }
+                else
+                {
+                    mean_tensor(out_dstr_idx) += x;
+                    var_tensor(out_dstr_idx) += x * x;
+                }
             });
         }
     });
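For reference, the kWelford branch's per-sample step is the textbook Welford update; a scalar sketch written for this edit (the library's welford_update also has a fast-fdiv variant not shown):

#include <cstdio>

// Textbook Welford update: feed one sample at a time, keep a running mean and
// the unnormalized second moment m2; var = m2 / count at the end.
void welford_update_ref(float& mean, float& m2, float x, int count)
{
    float delta = x - mean;
    mean += delta / count;      // count is the 1-based sample index
    m2 += delta * (x - mean);   // uses the already-updated mean
}

int main()
{
    float mean = 0.f, m2 = 0.f;
    float xs[] = {1.f, 2.f, 3.f, 4.f};
    for(int i = 0; i < 4; ++i)
        welford_update_ref(mean, m2, xs[i], i + 1);
    std::printf("mean=%f var=%f\n", mean, m2 / 4); // mean=2.5 var=1.25
    return 0;
}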
@@ -91,10 +98,11 @@ struct BlockWelford
 };
 template <typename Problem_, typename Policy_ = void>
-struct BlockWelfordSync
+struct BlockNormReduceSync
 {
     using Problem = remove_cvref_t<Problem_>;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
     CK_TILE_DEVICE void
@@ -152,36 +160,48 @@ struct BlockWelfordSync
                         (number<lid_over_rid_derivative << istage.value>{}.value);
                     // pull data from remote lane
                     const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
                     const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
-                    const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
-
-                    // welford merge
-                    welford_merge(v_local_mean,
-                                  v_local_var,
-                                  v_local_count,
-                                  v_remote_mean,
-                                  v_remote_var,
-                                  v_remote_count,
-                                  constant<kFastFDiv>{});
+                    if(kWelford)
+                    {
+                        const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
+
+                        // norm_reduce merge
+                        welford_merge(v_local_mean,
+                                      v_local_var,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_var,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
+                    else
+                    {
+                        v_local_mean += v_remote_mean;
+                        v_local_var += v_remote_var;
+                    }
                 });
             }
         });
         mean_tensor.get_thread_buffer()(i) = v_local_mean;
         var_tensor.get_thread_buffer()(i) = v_local_var;
-        count = v_local_count;
+        if(kWelford)
+        {
+            count = v_local_count;
+        }
     });
 }
 };
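The shuffle stages above combine partial results; under Welford this is the parallel-variance merge (Chan et al.), which the plain-sum path replaces with simple additions. A standalone sketch of the merge, written for this edit (names are ours; the library's welford_merge updates the local triple in place, and var holds the unnormalized second moment until the post-scale step):

#include <cstdio>

// Partial statistics carried by one lane/warp: running mean, unnormalized
// second moment m2, and sample count.
struct Partial { float mean; float m2; int count; };

// Chan et al. parallel-variance merge of two partials.
Partial welford_merge_ref(Partial a, Partial b)
{
    int n = a.count + b.count;
    if(n == 0) return Partial{0.f, 0.f, 0};
    float delta = b.mean - a.mean;
    float mean = a.mean + delta * b.count / n;
    float m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / n;
    return Partial{mean, m2, n};
}

int main()
{
    Partial a{1.5f, 0.5f, 2}; // from samples {1, 2}
    Partial b{3.5f, 0.5f, 2}; // from samples {3, 4}
    Partial m = welford_merge_ref(a, b);
    std::printf("mean=%f var=%f\n", m.mean, m.m2 / m.count); // mean=2.5 var=1.25
    return 0;
}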
 template <typename Problem_, typename Policy_ = void>
-struct BlockWelfordCrossWarpSync
+struct BlockNormReduceCrossWarpSync
 {
     using Problem = remove_cvref_t<Problem_>;
     using BlockShape = typename Problem::BlockShape;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
+    using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
         static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
-        // Note: we always pack everything into fp32x4
-        fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
+        smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
         const index_t lane_id = get_lane_id();
         const index_t warp_id = get_warp_id();
         constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
         if(lane_id == 0)
         {
             static_for<0, thread_buf_size, 1>{}([&](auto i) {
-                fp32x4_t local_scratch_;
+                smem_dtype local_scratch_;
                 local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
                 local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
-                local_scratch_[2] = bit_cast<float>(count);
+                if(kWelford)
+                {
+                    local_scratch_[2] = bit_cast<float>(count);
+                }
                 smem_ptr[smem_offset + i * num_warps] = local_scratch_;
             });
         }
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
         // load from smem. here we let everythread to do compute :)
         index_t local_warp_id = warp_id / num_reduce_warps;
         index_t local_smem_os = local_warp_id * num_reduce_warps;
-        fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
+        smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
         static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
             static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                 all_scratch[i_0 * num_reduce_warps + i_1] =
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
         static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
             // TODO: use descriptor for this
             auto v_local = all_scratch[i_0 * num_reduce_warps];
             auto v_local_mean = bit_cast<DataType>(v_local[0]);
             auto v_local_var = bit_cast<DataType>(v_local[1]);
-            auto v_local_count = bit_cast<int>(v_local[2]);
+            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
             // further reduce mean/var
             static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                 constexpr auto i_1 = number<i_1_n1 + 1>{};
-                const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
+                const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                 const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
                 const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
-                const auto v_remote_count = bit_cast<int>(v_remote[2]);
-
-                welford_merge(v_local_mean,
-                              v_local_var,
-                              v_local_count,
-                              v_remote_mean,
-                              v_remote_var,
-                              v_remote_count,
-                              constant<kFastFDiv>{});
+                if(kWelford)
+                {
+                    const auto v_remote_count = bit_cast<int>(v_remote[2]);
+
+                    welford_merge(v_local_mean,
+                                  v_local_var,
+                                  v_local_count,
+                                  v_remote_mean,
+                                  v_remote_var,
+                                  v_remote_count,
+                                  constant<kFastFDiv>{});
+                }
+                else
+                {
+                    v_local_mean += v_remote_mean;
+                    v_local_var += v_remote_var;
+                }
             });
             mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
             var_tensor.get_thread_buffer()(i_0) = v_local_var;
-            count = v_local_count;
+            if(kWelford)
+                count = v_local_count;
         });
     }
 };
...
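A note on the smem_dtype switch: with Welford each reducing warp spills a (mean, var, count) triple padded to a 16-byte fp32x4, while the plain-sum path needs only (sum, sum-of-squares), so fp32x2 halves the LDS scratch. A standalone sketch of the packing, using clang vector extensions as stand-ins for ck_tile's ext_vector_t aliases:

#include <cstdio>
#include <cstring>

// Stand-ins for the library's vector aliases (clang ext_vector_type extension).
typedef float fp32x2_ref __attribute__((ext_vector_type(2)));
typedef float fp32x4_ref __attribute__((ext_vector_type(4)));

int main()
{
    float mean = 2.5f, var = 1.25f;
    int count = 4;
    // Welford path: (mean, var, count) packed into one fp32x4 slot, the count
    // bit-cast into a float lane the same way the kernel's bit_cast does.
    fp32x4_ref w;
    w[0] = mean;
    w[1] = var;
    float count_bits;
    std::memcpy(&count_bits, &count, sizeof(count_bits));
    w[2] = count_bits;
    w[3] = 0.f; // padding lane
    // Plain-sum path: only (sum, sum-of-squares) per slot.
    fp32x2_ref s = {mean, var};
    float lane2 = w[2];
    int back;
    std::memcpy(&back, &lane2, sizeof(back));
    std::printf("count roundtrip=%d, s[0]=%f\n", back, (float)s[0]);
    return 0;
}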
@@ -7,13 +7,18 @@
 namespace ck_tile {
-template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
-struct BlockWelfordProblem
+template <typename XDataType_,
+          typename ComputeDataType_,
+          typename BlockShape_,
+          bool kFastFDiv_,
+          bool kWelford_>
+struct BlockNormReduceProblem
 {
     using XDataType = remove_cvref_t<XDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
     using BlockShape = remove_cvref_t<BlockShape_>;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
 };
 } // namespace ck_tile
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
     index_t m;
     index_t n;
-    index_t stride; // row_stride
+    index_t x_stride; // input row_stride
+    index_t y_stride; // output row_stride
 };
 // TODO: Extract some type to wrapper class
@@ -58,14 +59,21 @@ struct Smoothquant
         index_t m;
         index_t n;
-        index_t stride; // row_stride
+        index_t x_stride; // input row_stride
+        index_t y_stride; // out row_stride
     };
     using Hargs = SmoothquantHostArgs;
     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
-        return Kargs{
-            hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride};
+        return Kargs{hargs.p_x,
+                     hargs.p_xscale,
+                     hargs.p_yscale,
+                     hargs.p_qy,
+                     hargs.m,
+                     hargs.n,
+                     hargs.x_stride,
+                     hargs.y_stride};
     }
     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
@@ -116,7 +124,7 @@ struct Smoothquant
         const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<const XDataType*>(kargs.p_x),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.x_stride, 1),
             number<Vector_N>{},
             number<1>{});
@@ -157,7 +165,7 @@ struct Smoothquant
         auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<QYDataType*>(kargs.p_qy),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.y_stride, 1),
             number<Vector_N>{},
             number<1>{});
...
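Splitting the single stride into x_stride/y_stride lets input and output rows have different pitches, e.g. a padded input buffer feeding a densely packed quantized output. A standalone sketch written for this edit (the per-row smooth-quant scaling is elided):

#include <cstdio>

int main()
{
    constexpr int m = 2, n = 3;
    constexpr int x_stride = 8; // input rows padded to 8 floats
    constexpr int y_stride = 3; // output rows packed tight
    float x[m * x_stride] = {1, 2, 3, 0, 0, 0, 0, 0,
                             4, 5, 6, 0, 0, 0, 0, 0};
    signed char qy[m * y_stride];
    for(int r = 0; r < m; ++r)
        for(int c = 0; c < n; ++c) // quantization math elided, copy as-is
            qy[r * y_stride + c] = static_cast<signed char>(x[r * x_stride + c]);
    std::printf("qy[1][2] = %d\n", qy[1 * y_stride + 2]); // 6
    return 0;
}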
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...