Commit c2ea75ed authored by Jiming Ruan's avatar Jiming Ruan

Add no variance support to block_norm_reduce

parent b16fad32
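
The diff below threads a kComputeVariance flag through BlockNormReduce, BlockNormReduceSync and BlockNormReduceCrossWarpSync so the variance accumulator (and its shuffle/LDS traffic) can be skipped when only the mean is needed. For reference, here is a minimal host-side sketch of the standard Welford recurrences that the kWelford paths specialize, in both the mean-only and mean+variance forms. It is illustrative only and is not ck_tile code; ck_tile's welford_update/welford_merge may differ in detail (e.g. in how the variance accumulator is normalized or in the kFastFDiv path).

#include <cstdio>

// Mean-only Welford update: a single running mean, no second-moment accumulator.
static void welford_update_mean(float& mean, float x, int& count)
{
    count += 1;
    mean += (x - mean) / count;
}

// Mean + variance Welford update: m2 accumulates the sum of squared deviations.
static void welford_update_mean_var(float& mean, float& m2, float x, int& count)
{
    count += 1;
    float delta = x - mean;
    mean += delta / count;
    m2 += delta * (x - mean);
}

// Chan-style merge of two partial results (what a cross-lane/cross-warp reduce does).
static void welford_merge_mean_var(
    float& mean_a, float& m2_a, int& count_a, float mean_b, float m2_b, int count_b)
{
    const int count = count_a + count_b;
    if(count == 0)
        return;
    const float delta = mean_b - mean_a;
    mean_a += delta * count_b / count;
    m2_a += m2_b + delta * delta * (float(count_a) * count_b / count);
    count_a = count;
}

int main()
{
    const float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};

    float mean = 0.f, m2 = 0.f;
    int count = 0;
    for(float x : data)
        welford_update_mean_var(mean, m2, x, count);
    std::printf("mean=%f var=%f\n", mean, m2 / count); // mean=4.5, var=5.25

    float mean_only = 0.f;
    int count_mean_only = 0;
    for(float x : data)
        welford_update_mean(mean_only, x, count_mean_only);
    std::printf("mean-only=%f\n", mean_only); // 4.5, without touching any variance state

    // Two partial reductions merged afterwards give the same result.
    float mean_lo = 0.f, m2_lo = 0.f, mean_hi = 0.f, m2_hi = 0.f;
    int n_lo = 0, n_hi = 0;
    for(int i = 0; i < 4; ++i)
        welford_update_mean_var(mean_lo, m2_lo, data[i], n_lo);
    for(int i = 4; i < 8; ++i)
        welford_update_mean_var(mean_hi, m2_hi, data[i], n_hi);
    welford_merge_mean_var(mean_lo, m2_lo, n_lo, mean_hi, m2_hi, n_hi);
    std::printf("merged mean=%f var=%f\n", mean_lo, m2_lo / n_lo); // 4.5, 5.25
    return 0;
}
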
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
// attention! the two vector types could be just the same type
// fp64
using fp64_t = double;
using fp64x1_t = double __attribute__((ext_vector_type(1)));
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32
using fp32_t = float;
using fp32x1_t = float __attribute__((ext_vector_type(1)));
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8)));
......@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// using fp16_t = ...
using fp16x1_t = _Float16 __attribute__((ext_vector_type(1)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
......@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x1_t = bf16_raw_t __attribute__((ext_vector_type(1)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
......@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...
using int32x1_t = int32_t __attribute__((ext_vector_type(1)));
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
......@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x1_t = uint32_t __attribute__((ext_vector_type(1)));
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
......@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x1_t = int16_t __attribute__((ext_vector_type(1)));
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
......@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
// using uint16_t
using uint16x1_t = uint16_t __attribute__((ext_vector_type(1)));
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
......@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x1_t = int8_t __attribute((ext_vector_type(1)));
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
......@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// ui8
// using uint8_t
using uint8x1_t = uint8_t __attribute((ext_vector_type(1)));
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
......@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x1_t = fp8_raw_t __attribute((ext_vector_type(1)));
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
......@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x1_t = bf8_raw_t __attribute((ext_vector_type(1)));
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
......@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x1_t = fp8_t __attribute((ext_vector_type(1)));
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
......@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x1_t = bf8_t __attribute((ext_vector_type(1)));
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockNormReduce<P_>{};
......@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
......@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
......@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -11,14 +11,14 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduce
{
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
CK_TILE_DEVICE constexpr BlockNormReduce() {}
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kComputeVariance = Problem::kComputeVariance;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
private:
// [CAUTION] - max_count_ is to deal with the padding problem
// max_count_ depends on the caller, e.g. naive and splitN norm_reduce will have different
// calculations of max_count_
......@@ -26,16 +26,18 @@ struct BlockNormReduce
template <typename XDistributedTensor_,
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& cur_count_, // -> prefer init as zero
const int& max_count_)
CK_TILE_DEVICE void Impl(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& cur_count_, // -> prefer init as zero
const int& max_count_)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
constexpr bool computeVariance =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
if(cur_count_ < max_count_)
......@@ -48,22 +50,57 @@ struct BlockNormReduce
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
if(kWelford)
{
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
if constexpr(computeVariance)
{
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
}
else
{
welford_update(
mean_tensor(out_dstr_idx), x, cur_count_, constant<kFastFDiv>{});
}
}
else
{
mean_tensor(out_dstr_idx) += x;
var_tensor(out_dstr_idx) += x * x;
if constexpr(computeVariance)
{
var_tensor(out_dstr_idx) += x * x;
}
}
});
}
});
}
public:
CK_TILE_DEVICE constexpr BlockNormReduce() {}
template <typename XDistributedTensor_,
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& cur_count_,
const int& max_count_)
{
Impl(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
}
template <typename XDistributedTensor_, typename MeanDistributedTensor_>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
int& cur_count_,
const int& max_count_)
{
null_tensor dummy_var{}; // placeholder: this overload does not accumulate variance
Impl(x_tensor, mean_tensor, dummy_var, cur_count_, max_count_);
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
{
......@@ -82,12 +119,13 @@ struct BlockNormReduce
return tensor;
}
template <typename XDistributedTensor_>
template <bool kComputeVariance, typename XDistributedTensor_>
CK_TILE_DEVICE auto
operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
{
auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
auto var_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
auto var_tensor = [&] {
    if constexpr(kComputeVariance)
        return MakeMeanVarBlockTile<XDistributedTensor_>();
    else
        return null_tensor{};
}();
clear_tile(mean_tensor);
if constexpr(kComputeVariance)
{
    clear_tile(var_tensor);
}
......@@ -100,13 +138,15 @@ struct BlockNormReduce
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceSync
{
using Problem = remove_cvref_t<Problem_>;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using Problem = remove_cvref_t<Problem_>;
static constexpr bool kComputeVariance = Problem::kComputeVariance;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
private:
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
Impl(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
{
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
......@@ -120,19 +160,23 @@ struct BlockNormReduceSync
constexpr index_t idim_p_lane = NDimP - 1;
constexpr bool computeVariance =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
static_assert((computeVariance == false) ||
(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
const int original_count = count;
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
auto v_local_var = var_tensor.get_thread_buffer()[i];
auto v_local_var = [&] {
    if constexpr(computeVariance)
        return var_tensor.get_thread_buffer()[i];
    else
        return decltype(v_local_mean){0};
}();
auto v_local_count = original_count;
// cross-lane reduce for replication
......@@ -161,24 +205,39 @@ struct BlockNormReduceSync
// pull data from remote lane
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
const auto v_remote_var =
computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
if(kWelford)
{
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// norm_reduce merge
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
if constexpr(computeVariance)
{
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
welford_merge(v_local_mean,
v_local_count,
v_remote_mean,
v_remote_count,
constant<kFastFDiv>{});
}
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
if constexpr(computeVariance)
{
v_local_var += v_remote_var;
}
}
});
}
......@@ -192,16 +251,34 @@ struct BlockNormReduceSync
}
});
}
public:
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
{
Impl(mean_tensor, var_tensor, count);
}
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
{
null_tensor dummy_var{}; // placeholder: this overload does not reduce variance
Impl(mean_tensor, dummy_var, count);
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
static constexpr bool kComputeVariance = Problem::kComputeVariance;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using smem_dtype =
    std::conditional_t<kWelford,
                       std::conditional_t<kComputeVariance, fp32x4_t, fp32x2_t>,
                       std::conditional_t<kComputeVariance, fp32x2_t, fp32x1_t>>;
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
......@@ -234,7 +311,12 @@ struct BlockNormReduceCrossWarpSync
{
// constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
// data need to exchange is very small, we just pack mean+var+count -> 4dword
// the data we need to exchange is very small: we pack mean+var+count -> 4 dwords if the
// variance has to be calculated, or mean+count -> 2 dwords if only the mean is needed.
// Additionally, count is not exchanged if the Welford algorithm is not used.
constexpr index_t num_dw =
kWelford ? (kComputeVariance ? 4 : 2) : (kComputeVariance ? 2 : 1);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
......@@ -251,25 +333,32 @@ struct BlockNormReduceCrossWarpSync
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
return num_warps * 4 * thread_buf_size * sizeof(float);
return num_warps * num_dw * thread_buf_size * sizeof(float);
}
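    // Illustrative sizing (not part of this commit): assuming BlockSize = 256, a 64-lane
    // wave and thread_buf_size = 1, num_warps = 256 / 64 = 4 and GetSmemSize() returns
    //   kWelford  &&  kComputeVariance -> num_dw = 4 (mean, var, count, pad) -> 4*4*1*4 B = 64 B
    //   kWelford  && !kComputeVariance -> num_dw = 2 (mean, count)           -> 4*2*1*4 B = 32 B
    //   !kWelford &&  kComputeVariance -> num_dw = 2 (mean, var)             -> 4*2*1*4 B = 32 B
    //   !kWelford && !kComputeVariance -> num_dw = 1 (mean)                  -> 4*1*1*4 B = 16 B
    // i.e. the mean-only, non-Welford case needs a quarter of the previous fixed 4-dword scratch.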
private:
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& count,
void* smem)
CK_TILE_DEVICE void Impl(MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& count,
void* smem)
{
using DataType = typename MeanDistributedTensor_::DataType;
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
// using DstrEncode = typename Dstr::DstrEncode;
// using DstrEncodeDetail = typename DstrEncode::detail;
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
constexpr bool computeVariance =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
static_assert(
(computeVariance == false) ||
std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
static_assert((computeVariance == false) ||
(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
// Note: we pack everything into smem_dtype (fp32x1/x2/x4, depending on kWelford and
// kComputeVariance)
smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
......@@ -289,10 +378,17 @@ struct BlockNormReduceCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i) {
smem_dtype local_scratch_;
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
if(kWelford)
if constexpr(kComputeVariance)
{
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
if constexpr(kWelford)
{
local_scratch_[2] = bit_cast<float>(count);
}
}
else if constexpr(kWelford)
{
local_scratch_[2] = bit_cast<float>(count);
local_scratch_[1] = bit_cast<float>(count);
}
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
});
......@@ -317,40 +413,74 @@ struct BlockNormReduceCrossWarpSync
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]);
int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
auto v_local_var = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
int v_local_count = kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2])
: bit_cast<int>(v_local[1]))
: 0;
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
if(kWelford)
const auto v_remote_var = kComputeVariance ? bit_cast<DataType>(v_remote[1]) : 0;
if constexpr(kWelford)
{
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
if constexpr(kComputeVariance)
{
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
const auto v_remote_count = bit_cast<int>(v_remote[1]);
welford_merge(v_local_mean,
v_local_count,
v_remote_mean,
v_remote_count,
constant<kFastFDiv>{});
}
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
if constexpr(kComputeVariance)
{
v_local_var += v_remote_var;
}
}
});
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
if constexpr(computeVariance)
{
    var_tensor.get_thread_buffer()(i_0) = v_local_var;
}
if(kWelford)
if constexpr(kWelford)
{
count = v_local_count;
}
});
}
public:
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor,
int& count,
void* smem)
{
Impl(mean_tensor, var_tensor, count, smem);
}
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
{
null_tensor dummy_var{}; // placeholder: this overload does not reduce variance
Impl(mean_tensor, dummy_var, count, smem);
}
};
// compute the max count for a last dim reduce
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -10,15 +10,17 @@ namespace ck_tile {
template <typename XDataType_,
typename ComputeDataType_,
typename BlockShape_,
bool kComputeVariance_,
bool kFastFDiv_,
bool kWelford_>
struct BlockNormReduceProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kComputeVariance = kComputeVariance_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
};
} // namespace ck_tile
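
For a caller that only needs the mean, the new kComputeVariance_ slot would be set to false (the Layernorm2d policy above keeps it true). Below is a minimal sketch of such a policy hook, assuming the ck_tile block_norm_reduce headers are included and a Problem type exposing ComputeDataType, BlockShape and Traits as in Layernorm2dFwdPipelineDefaultPolicy; MeanOnlyNormReducePolicy and its method are hypothetical and not part of this commit.

// Hypothetical policy that mirrors Layernorm2dFwdPipelineDefaultPolicy::GetBlockNormReduce()
// but requests the mean-only reduction path.
struct MeanOnlyNormReducePolicy
{
    template <typename Problem>
    CK_TILE_DEVICE static auto GetBlockNormReduce()
    {
        using P_ = ck_tile::BlockNormReduceProblem<typename Problem::ComputeDataType,
                                                   typename Problem::ComputeDataType,
                                                   typename Problem::BlockShape,
                                                   /* kComputeVariance_ = */ false,
                                                   Problem::Traits::kFastFDiv,
                                                   Problem::Traits::kWelford>;
        return ck_tile::BlockNormReduce<P_>{};
    }
};

A pipeline built on such a policy would then use the mean-only operator()(x_tensor, mean_tensor, cur_count_, max_count_) overloads added above, so no variance tile has to be allocated, cleared or exchanged.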