Commit c2ea75ed authored by Jiming Ruan

Add no variance support to block_norm_reduce

parent b16fad32
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
 // attention! 2 vector type could be just the same type
 // fp64
 using fp64_t = double;
+using fp64x1_t = double __attribute__((ext_vector_type(1)));
 using fp64x2_t = double __attribute__((ext_vector_type(2)));
 using fp64x4_t = double __attribute__((ext_vector_type(4)));
 // fp32
 using fp32_t = float;
+using fp32x1_t = float __attribute__((ext_vector_type(1)));
 using fp32x2_t = float __attribute__((ext_vector_type(2)));
 using fp32x4_t = float __attribute__((ext_vector_type(4)));
 using fp32x8_t = float __attribute__((ext_vector_type(8)));
@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
 // fp16
 // using fp16_t = ...
+using fp16x1_t = _Float16 __attribute__((ext_vector_type(1)));
 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
 using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
 using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
 // bf16
 // using bf16_t = ...
+using bf16x1_t = bf16_raw_t __attribute__((ext_vector_type(1)));
 using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
 using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
 using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
 // i32
 // using int32_t = ...
+using int32x1_t = int32_t __attribute__((ext_vector_type(1)));
 using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
 using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
 using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
 // u32
 // using uint32_t = ...
+using uint32x1_t = uint32_t __attribute__((ext_vector_type(1)));
 using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
 using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
 using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
 // i16
 // using int16_t = ...
+using int16x1_t = int16_t __attribute__((ext_vector_type(1)));
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
 using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
 using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
 // u16
 // using uint16_t
+using uint16x1_t = uint16_t __attribute__((ext_vector_type(1)));
 using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
 using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
 using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
 // i8
 // using int8_t
+using int8x1_t = int8_t __attribute((ext_vector_type(1)));
 using int8x2_t = int8_t __attribute((ext_vector_type(2)));
 using int8x4_t = int8_t __attribute((ext_vector_type(4)));
 using int8x8_t = int8_t __attribute((ext_vector_type(8)));
@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
 // ui8
 // using uint8_t
+using uint8x1_t = uint8_t __attribute((ext_vector_type(1)));
 using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
 using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
 using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
 #if CK_TILE_USE_CUSTOM_DATA_TYPE
 // f8
 // using fp8_t
+using fp8x1_t = fp8_raw_t __attribute((ext_vector_type(1)));
 using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
 using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
 using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
 // bf8
 // using bf8_t
+using bf8x1_t = bf8_raw_t __attribute((ext_vector_type(1)));
 using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
 using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
 using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
 #else
 // f8
 // using fp8_t
+using fp8x1_t = fp8_t __attribute((ext_vector_type(1)));
 using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
 using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
 using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
 // bf8
 // using bf8_t
+using bf8x1_t = bf8_t __attribute((ext_vector_type(1)));
 using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
 using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
 using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
...
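The width-1 aliases added above let the cross-warp scratch type introduced later in this commit shrink to a single dword. For reference, a width-1 Clang extended vector still supports subscripting, so packing code can index element 0 the same way it does for the wider types. A minimal standalone sketch (the names below are illustrative, not part of the commit, and assume Clang's ext_vector_type extension):

// Illustrative only: a width-1 ext_vector_type behaves like a one-element vector
// and supports operator[], matching how the wider fp32x2_t / fp32x4_t are used.
using f32x1_demo = float __attribute__((ext_vector_type(1)));

inline float roundtrip(float mean)
{
    f32x1_demo v;
    v[0] = mean; // same subscript syntax as the wider vector aliases
    return v[0];
}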
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
         return BlockNormReduce<P_>{};
@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
...
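The default layernorm policy keeps variance enabled by passing true for the new kComputeVariance slot of BlockNormReduceProblem. A hypothetical mean-only getter (illustrative only; the CK_TILE_DEVICE qualifier, the Traits member names, and the helper name are assumed from the surrounding code and are not part of this commit) would differ only in that argument:

// Hypothetical policy helper, not part of this commit: build a mean-only reducer.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetMeanOnlyBlockNormReduce()
{
    using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                      typename Problem::ComputeDataType,
                                      typename Problem::BlockShape,
                                      /*kComputeVariance_=*/false, // only change vs. the getters above
                                      Problem::Traits::kFastFDiv,
                                      Problem::Traits::kWelford>;
    return BlockNormReduce<P_>{};
}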
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -11,14 +11,14 @@ namespace ck_tile {
 template <typename Problem_, typename Policy_ = void>
 struct BlockNormReduce
 {
     using Problem = remove_cvref_t<Problem_>;
     using XDataType = typename Problem::XDataType;
     using ComputeDataType = typename Problem::ComputeDataType;
-    static constexpr bool kFastFDiv = Problem::kFastFDiv;
-    static constexpr bool kWelford = Problem::kWelford;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
+    static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
-    CK_TILE_DEVICE constexpr BlockNormReduce() {}
+    private:
     // [CAUSION] - max_count_ is to deal with the padding problem
     // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
     // calculation of max_count_
@@ -26,16 +26,18 @@ struct BlockNormReduce
     template <typename XDistributedTensor_,
               typename MeanDistributedTensor_,
               typename VarDistributedTensor_>
-    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+    CK_TILE_DEVICE void Impl(const XDistributedTensor_& x_tensor,
                              MeanDistributedTensor_& mean_tensor,
                              VarDistributedTensor_& var_tensor,
                              int& cur_count_, // -> prefer init as zero
                              const int& max_count_)
     {
         constexpr auto I0 = number<0>{};
         constexpr auto I1 = number<1>{};
         constexpr auto spans = XDistributedTensor_::get_distributed_spans();
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
         sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
             if(cur_count_ < max_count_)
@@ -48,22 +50,57 @@ struct BlockNormReduce
                     auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
                     if(kWelford)
                     {
-                        welford_update(mean_tensor(out_dstr_idx),
-                                       var_tensor(out_dstr_idx),
-                                       x,
-                                       cur_count_,
-                                       constant<kFastFDiv>{});
+                        if constexpr(computeVariance)
+                        {
+                            welford_update(mean_tensor(out_dstr_idx),
+                                           var_tensor(out_dstr_idx),
+                                           x,
+                                           cur_count_,
+                                           constant<kFastFDiv>{});
+                        }
+                        else
+                        {
+                            welford_update(
+                                mean_tensor(out_dstr_idx), x, cur_count_, constant<kFastFDiv>{});
+                        }
                     }
                     else
                     {
                         mean_tensor(out_dstr_idx) += x;
-                        var_tensor(out_dstr_idx) += x * x;
+                        if constexpr(computeVariance)
+                        {
+                            var_tensor(out_dstr_idx) += x * x;
+                        }
                     }
                 });
             }
         });
     }
+    public:
+    CK_TILE_DEVICE constexpr BlockNormReduce() {}
+    template <typename XDistributedTensor_,
+              typename MeanDistributedTensor_,
+              typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor,
+                                   int& cur_count_,
+                                   const int& max_count_)
+    {
+        Impl(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
+    }
+    template <typename XDistributedTensor_, typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   int& cur_count_,
+                                   const int& max_count_)
+    {
+        Impl(x_tensor, mean_tensor, null_tensor{}, cur_count_, max_count_);
+    }
     template <typename XDistributedTensor_>
     CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
     {
@@ -82,12 +119,13 @@ struct BlockNormReduce
         return tensor;
     }
-    template <typename XDistributedTensor_>
+    template <bool kComputeVariance, typename XDistributedTensor_>
     CK_TILE_DEVICE auto
     operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
     {
         auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
-        auto var_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
+        auto var_tensor =
+            kComputeVariance ? MakeMeanVarBlockTile<XDistributedTensor_>() : null_tensor{};
         clear_tile(mean_tensor);
         clear_tile(var_tensor);
@@ -100,13 +138,15 @@ struct BlockNormReduce
 template <typename Problem_, typename Policy_ = void>
 struct BlockNormReduceSync
 {
     using Problem = remove_cvref_t<Problem_>;
-    static constexpr bool kFastFDiv = Problem::kFastFDiv;
-    static constexpr bool kWelford = Problem::kWelford;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
+    static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
+    private:
     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
     CK_TILE_DEVICE void
-    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    Impl(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
     {
         using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
         using DstrEncode = typename Dstr::DstrEncode;
@@ -120,19 +160,23 @@ struct BlockNormReduceSync
         constexpr index_t idim_p_lane = NDimP - 1;
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
         // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
         // const auto rs_idx =
         //     mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        static_assert((computeVariance == false) ||
+                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
         const int original_count = count;
         // loop over thread data
         static_for<0, thread_buf_size, 1>{}([&](auto i) {
             auto v_local_mean = mean_tensor.get_thread_buffer()[i];
-            auto v_local_var = var_tensor.get_thread_buffer()[i];
+            auto v_local_var = computeVariance ? var_tensor.get_thread_buffer()[i] : 0;
             auto v_local_count = original_count;
             // cross-lane reduce for replication
@@ -161,24 +205,39 @@ struct BlockNormReduceSync
                 // pull data from remote lane
                 const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
-                const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
+                const auto v_remote_var =
+                    computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
                 if(kWelford)
                 {
                     const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
                     // norm_reduce merge
-                    welford_merge(v_local_mean,
-                                  v_local_var,
-                                  v_local_count,
-                                  v_remote_mean,
-                                  v_remote_var,
-                                  v_remote_count,
-                                  constant<kFastFDiv>{});
+                    if constexpr(computeVariance)
+                    {
+                        welford_merge(v_local_mean,
+                                      v_local_var,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_var,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
+                    else
+                    {
+                        welford_merge(v_local_mean,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
                 }
                 else
                 {
                     v_local_mean += v_remote_mean;
-                    v_local_var += v_remote_var;
+                    if constexpr(computeVariance)
+                    {
+                        v_local_var += v_remote_var;
+                    }
                 }
             });
         }
@@ -192,16 +251,34 @@ struct BlockNormReduceSync
             }
         });
     }
+    public:
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void
+    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    {
+        Impl(mean_tensor, var_tensor, count);
+    }
+    template <typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
+    {
+        Impl(mean_tensor, null_tensor{}, count);
+    }
 };
 template <typename Problem_, typename Policy_ = void>
 struct BlockNormReduceCrossWarpSync
 {
     using Problem = remove_cvref_t<Problem_>;
     using BlockShape = typename Problem::BlockShape;
-    static constexpr bool kFastFDiv = Problem::kFastFDiv;
-    static constexpr bool kWelford = Problem::kWelford;
-    using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
+    static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford = Problem::kWelford;
+    using smem_dtype =
+        std::conditional_t<kWelford,
                            typename std::conditional<kComputeVariance, fp32x4_t, fp32x2_t>::type,
                            typename std::conditional<kComputeVariance, fp32x2_t, fp32x1_t>::type>;
     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
@@ -234,7 +311,12 @@ struct BlockNormReduceCrossWarpSync
     {
         // constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
-        // data need to exchange is very small, we just pack mean+var+count -> 4dword
+        // data need to exchange is very small, we just pack mean+var+count -> 4dword if var should
+        // be calculated, or pack mean+count -> 2dword if only mean is taken into account.
+        // Additionally, count is not exchanged if Welford algorithm is not used.
+        constexpr index_t num_dw =
+            kWelford ? (kComputeVariance ? 4 : 2) : (kComputeVariance ? 2 : 1);
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
         // we need to store all data from every wave into smem
@@ -251,25 +333,32 @@ struct BlockNormReduceCrossWarpSync
         //
         // -> also store data from every wave into LDS
         constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
-        return num_warps * 4 * thread_buf_size * sizeof(float);
+        return num_warps * num_dw * thread_buf_size * sizeof(float);
     }
+    private:
     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
-    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
+    CK_TILE_DEVICE void Impl(MeanDistributedTensor_& mean_tensor,
                              VarDistributedTensor_& var_tensor,
                              int& count,
                              void* smem)
     {
         using DataType = typename MeanDistributedTensor_::DataType;
         using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
         // using DstrEncode = typename Dstr::DstrEncode;
         // using DstrEncodeDetail = typename DstrEncode::detail;
-        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
+        static_assert(
+            (computeVariance == false) ||
+                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+            "wrong!");
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        static_assert((computeVariance == false) ||
+                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
         // Note: we always pack everything into fp32x4
         smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
@@ -289,10 +378,17 @@ struct BlockNormReduceCrossWarpSync
         static_for<0, thread_buf_size, 1>{}([&](auto i) {
             smem_dtype local_scratch_;
             local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
-            local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
-            if(kWelford)
+            if constexpr(kComputeVariance)
+            {
+                local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
+                if constexpr(kWelford)
+                {
+                    local_scratch_[2] = bit_cast<float>(count);
+                }
+            }
+            else if constexpr(kWelford)
             {
-                local_scratch_[2] = bit_cast<float>(count);
+                local_scratch_[1] = bit_cast<float>(count);
             }
             smem_ptr[smem_offset + i * num_warps] = local_scratch_;
         });
@@ -317,40 +413,74 @@ struct BlockNormReduceCrossWarpSync
             // TODO: use descriptor for this
             auto v_local = all_scratch[i_0 * num_reduce_warps];
             auto v_local_mean = bit_cast<DataType>(v_local[0]);
-            auto v_local_var = bit_cast<DataType>(v_local[1]);
-            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
+            auto v_local_var = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
+            int v_local_count = kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2])
                                                              : bit_cast<int>(v_local[1]))
                                          : 0;
             // further reduce mean/var
             static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                 constexpr auto i_1 = number<i_1_n1 + 1>{};
                 const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                 const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
-                const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
-                if(kWelford)
+                const auto v_remote_var = kComputeVariance ? bit_cast<DataType>(v_remote[1]) : 0;
+                if constexpr(kWelford)
                 {
-                    const auto v_remote_count = bit_cast<int>(v_remote[2]);
-                    welford_merge(v_local_mean,
-                                  v_local_var,
-                                  v_local_count,
-                                  v_remote_mean,
-                                  v_remote_var,
-                                  v_remote_count,
-                                  constant<kFastFDiv>{});
+                    if constexpr(kComputeVariance)
+                    {
+                        const auto v_remote_count = bit_cast<int>(v_remote[2]);
+                        welford_merge(v_local_mean,
+                                      v_local_var,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_var,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
+                    else
+                    {
+                        const auto v_remote_count = bit_cast<int>(v_remote[1]);
+                        welford_merge(v_local_mean,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
                 }
                 else
                 {
                     v_local_mean += v_remote_mean;
-                    v_local_var += v_remote_var;
+                    if constexpr(kComputeVariance)
+                    {
+                        v_local_var += v_remote_var;
+                    }
                 }
             });
             mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
             var_tensor.get_thread_buffer()(i_0) = v_local_var;
-            if(kWelford)
+            if constexpr(kWelford)
+            {
                 count = v_local_count;
+            }
        });
    }
+    public:
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor,
+                                   int& count,
+                                   void* smem)
+    {
+        Impl(mean_tensor, var_tensor, count, smem);
+    }
+    template <typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
+    {
+        Impl(mean_tensor, null_tensor{}, count, smem);
+    }
 };
 // compute the max count for a last dim reduce
...
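The smem_dtype and num_dw choices above follow one rule: mean is always exchanged, variance adds a dword, the Welford count adds another, and the vector width is rounded up to a power of two, which is why mean+var+count occupies four dwords rather than three. A small self-contained restatement of that mapping (illustrative only; the helper name is not part of the commit):

// Illustrative only: dwords of LDS scratch per thread-buffer entry, mirroring num_dw above.
constexpr int dwords_per_entry(bool welford, bool compute_variance)
{
    return welford ? (compute_variance ? 4 : 2)  // {mean, var, count, pad} or {mean, count}
                   : (compute_variance ? 2 : 1); // {mean, var}             or {mean}
}

static_assert(dwords_per_entry(true, true) == 4, "fp32x4_t");
static_assert(dwords_per_entry(true, false) == 2, "fp32x2_t");
static_assert(dwords_per_entry(false, true) == 2, "fp32x2_t");
static_assert(dwords_per_entry(false, false) == 1, "fp32x1_t");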
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -10,15 +10,17 @@ namespace ck_tile {
 template <typename XDataType_,
           typename ComputeDataType_,
           typename BlockShape_,
+          bool kComputeVariance_,
           bool kFastFDiv_,
           bool kWelford_>
 struct BlockNormReduceProblem
 {
     using XDataType = remove_cvref_t<XDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
     using BlockShape = remove_cvref_t<BlockShape_>;
-    static constexpr bool kFastFDiv = kFastFDiv_;
-    static constexpr bool kWelford = kWelford_;
+    static constexpr bool kComputeVariance = kComputeVariance_;
+    static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
 };
 } // namespace ck_tile
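Putting the pieces together, the new overloads give a mean-only reduction path through all three blocks. A rough device-side sketch under assumed names (P_ is a BlockNormReduceProblem instantiated with kComputeVariance_ = false; x_tensor, max_count, and smem are set up as in the layernorm pipeline; none of this is part of the commit):

// Hypothetical mean-only flow: per-thread accumulate, intra-warp combine via
// lane shuffles, then cross-warp combine through LDS.
using Reducer = BlockNormReduce<P_>;
Reducer                          norm_reduce{};
BlockNormReduceSync<P_>          norm_sync{};
BlockNormReduceCrossWarpSync<P_> norm_cross_warp{};

auto mean_tensor = Reducer::MakeMeanVarBlockTile<XTensor>(); // XTensor: x_tensor's distributed-tensor type
clear_tile(mean_tensor);
int cur_count = 0;

norm_reduce(x_tensor, mean_tensor, cur_count, max_count); // mean-only operator() overload
norm_sync(mean_tensor, cur_count);                        // no var_tensor argument
norm_cross_warp(mean_tensor, cur_count, smem);            // exchanges 1-2 dwords per entry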