Commit 3c305fee authored by carlushuang

support fused dynamic-quant

parent 7fb9b2b6
@@ -195,7 +195,7 @@ float layernorm2d_fwd_(const S& s, A a)
     using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
     using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
-    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType,
+    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, typename Traits_::Shape,
         ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
     using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
......
@@ -203,18 +203,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(fused_sweep == 1)
     {
-        auto dquant_functor = [&](int m_, auto o_, auto acc_) {
+        auto dquant_functor = [&](int m_, auto& o_, const auto& acc_) {
             int N_ = acc_.mDesc.get_lengths()[1];
-            ComputeDataType absmax = 0;
+            ComputeDataType absmax = static_cast<ComputeDataType>(0);
             for(int n_ = 0; n_ < N_; n_++)
             {
-                const auto a = abs(acc_(m_, n_));
+                const auto a = ck_tile::abs(acc_(m_, n_));
                 absmax = a > absmax ? a : absmax;
             }
-            y_scale_host_ref(m_) = absmax / 127.0;
+            // printf("cpu:absmax:%f\n", absmax);
+            ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
+            y_scale_host_ref(m_) = ck_tile::type_convert<ScaleDataType>(y_scale);
             for(int n_ = 0; n_ < N_; n_++)
             {
-                o_(m_, n_) = static_cast<YDataType>(acc_(m_, n_) / y_scale_host_ref(m_));
+                o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
             }
         };
......
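As a reading aid, the sketch below restates the reference math of the dquant_functor above as a standalone, hypothetical plain-C++ helper with no ck_tile types: one scale per row from the row's absolute maximum over the int8 range, then a per-element divide. The zero-absmax guard is an extra safety that the reference does not have.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Per-row dynamic quantization to int8: one scale per row, then elementwise divide.
void dynamic_quant_row_ref(const float* acc, int N, float& y_scale, std::int8_t* out)
{
    float absmax = 0.f;
    for(int n = 0; n < N; ++n)
        absmax = std::max(absmax, std::fabs(acc[n])); // row-wise absolute maximum
    y_scale = absmax > 0.f ? absmax / 127.f : 1.f;    // int8 range is +/-127 (guard added here)
    for(int n = 0; n < N; ++n)
        out[n] = static_cast<std::int8_t>(std::lrintf(acc[n] / y_scale)); // |q| <= 127 by construction
}
```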
@@ -9,4 +9,5 @@
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -3,4 +3,5 @@
 #pragma once
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
new file: ck_tile/ops/common/generic_2d_block_shape.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {
/*
// clang-format off
    4-level descriptor: BlockTile -> WarpPerBlock -> WarpTile -> Vector

                         Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
           +<----------------------<  Repeat_N(2)  >--------------------->+
           |                                                              |
           +<--  <WarpPerBlock_N(2)>  -->+
                  Warp_N
           +--------------+--------------+--------------+--------------+----+------------------+
    Warp_M |    warp_0    |    warp_1    |              |              |  ^                  ^
           +--------------+--------------+              |              | <WarpPerBlock_M(2)> |
           |    warp_2    |    warp_3    |              |              |  v                  |
           +--------------+--------------+--------------+--------------+----+             Block_M
           |              |              |              |              |                    |
           +              +              |              |              |                    |
           |              |              |              |              |                    v
           +--------------+--------------+--------------+--------------+                    +

    each Warp-tile (e.g. 16 threads per row)
           Vector_N (contiguous pixels each thread holds along N, or vector size)
           +-----------+-----------+-----------+-----------+-----------+
           |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   |  ...         Vector_M
           +-----------+-----------+-----------+-----------+-----------+
           |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  |  ...
           +-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,      // block size, seq<M, N>
          typename WarpPerBlock_,   // num warps along seq<M, N>
          typename WarpTile_,       // warp size, seq<M, N>
          typename Vector_,         // contiguous pixels (vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Generic2dBlockShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile
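Below is a hypothetical instantiation of the new shape helper (the tile sizes are chosen for illustration and are not taken from this commit), showing how the 4-level descriptor resolves into the derived constants. It assumes a ck_tile build environment; the block size is passed explicitly for a 64-lane wave with 2 warps so the sketch does not depend on the warpSize default argument.

```cpp
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"

using Shape = ck_tile::Generic2dBlockShape<ck_tile::sequence<2, 256>, // BlockTile: Block_M = 2, Block_N = 256
                                           ck_tile::sequence<1, 2>,   // WarpPerBlock: 1 x 2 warps
                                           ck_tile::sequence<1, 64>,  // WarpTile: Warp_M = 1, Warp_N = 64
                                           ck_tile::sequence<1, 8>,   // Vector: 8 contiguous pixels per thread along N
                                           128>;                      // BlockSize = 64 lanes * 2 warps

static_assert(Shape::Repeat_M == 2);        // 2 / (1 * 1)
static_assert(Shape::Repeat_N == 2);        // 256 / (2 * 64)
static_assert(Shape::ThreadPerWarp_N == 8); // 64 / 8
```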
@@ -4,4 +4,5 @@
 #pragma once
 #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -6,4 +6,5 @@
 #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
 #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -4,7 +4,7 @@
 #pragma once
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce.hpp"
+#include "ck_tile/ops/reduce.hpp"
 namespace ck_tile {
@@ -18,12 +18,17 @@ struct DynamicQuantEpilogueTraits
 };
 // this epilogue just store out a M*N matrix, row major
-template <typename AccDataType_, typename YScaleDataType_, typename ODataType_, typename Traits_>
+template <typename AccDataType_,
+          typename YScaleDataType_,
+          typename ODataType_,
+          typename BlockShape_,
+          typename Traits_>
 struct DynamicQuantEpilogueProblem
 {
     using AccDataType = remove_cvref_t<AccDataType_>;
     using YScaleDataType = remove_cvref_t<YScaleDataType_>;
     using ODataType = remove_cvref_t<ODataType_>;
+    using BlockShape = remove_cvref_t<BlockShape_>; // can consume a generic 2d shape
     using Traits = remove_cvref_t<Traits_>;
 };
@@ -34,42 +34,81 @@ struct DynamicQuantEpilogue
     using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
     using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
     using ODataType = remove_cvref_t<typename Problem::ODataType>;
+    using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
     static constexpr bool kPadM = Problem::Traits::kPadM;
     static constexpr bool kPadN = Problem::Traits::kPadN;
     static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
     static constexpr bool UseMax3 = Problem::Traits::UseMax3;
-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2d<P_>{};
+    }
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2dSync<P_>{};
+    }
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2dCrossWarpSync<P_>{};
+    }
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+        return reduce_crosswarp_sync.GetSmemSize();
+    }
     // TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size
     // how do we fix this?
     template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                    YScaleWindow& y_scale_window,
-                                   const OAccTile& o_acc_tile)
+                                   const OAccTile& o_acc_tile,
+                                   void* smem)
     {
-        // compute row max
-        auto reduce_row_absmax = BlockReduce2D{o_acc_tile, type_convert<AccDataType>(0)};
-        auto row_absmax = [&]() {
-            if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
-            {
-                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
-                // const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
-                //     float rtn;
-                //     asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
-                //                  : "=v"(rtn)
-                //                  : "v"(acc_), "v"(v_0_), "v"(v_1_));
-                //     return rtn;
-                // };
-                // return reduce_row_absmax(f_max3, f_max, sequence<1, 2>{});
-                return reduce_row_absmax(f_max);
-            }
-            else
-            {
-                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
-                return reduce_row_absmax(f_max);
-            }
-        }();
+        auto reduce = GetBlockReduce2d();
+        auto reduce_sync = GetBlockReduce2dSync();
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
+        auto row_absmax = [&]() {
+            constexpr auto y_size_per_row =
+                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
+                    number<1>{});
+            // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
+            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
+            {
+                // fast max3 implementation
+                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
+                    float rtn;
+                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                                 : "=v"(rtn)
+                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
+                    return rtn;
+                };
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
+            }
+            else
+            {
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
+            }
+        }();
+        reduce_sync(row_absmax, f_absmax);
+        reduce_crosswarp_sync(row_absmax, smem, f_absmax);
+#if 0
+        sweep_tile(row_absmax, [&](auto idx) {
+            auto ddd = row_absmax[idx];
+            printf("tid:%d, absmax:%f\n", static_cast<int>(threadIdx.x), ddd);
+        });
+#endif
         // here y_scale is AccType, need to convert to YScaleType later
         auto y_scale = tile_elementwise_in(
@@ -80,15 +124,23 @@ struct DynamicQuantEpilogue
         store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
+        auto o_acc_scaled_tile =
+            make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution());
+        sweep_tile(o_acc_tile, [&](auto idx) {
+            constexpr auto row_id = make_tuple(idx[number<0>{}]);
+            o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
+        });
         // TODO: this is ugly
         if constexpr(UseRawStore && (kPadM || kPadN))
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
         }
     }
 };
......
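The reworked epilogue composes three reduction stages to obtain one absmax per row: a per-thread partial reduce (BlockReduce2d), a cross-lane combine (BlockReduce2dSync), and a cross-warp combine through the new smem workspace (BlockReduce2dCrossWarpSync, which is also what GetSmemSize() now sizes). The sketch below is a host-side, plain-C++ illustration of that three-stage structure, not the ck_tile API; the lane and warp counts are arbitrary assumptions.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hierarchical absmax over one row, mimicking the thread -> warp -> block reduction stages.
float row_absmax_hierarchical(const std::vector<float>& row, int lanes_per_warp, int num_warps)
{
    const int threads = lanes_per_warp * num_warps;
    std::vector<float> per_thread(threads, 0.f); // one partial result per "thread"
    std::vector<float> per_warp(num_warps, 0.f); // plays the role of the shared-memory buffer

    // stage 1: each thread reduces its own strided slice of the row
    for(int t = 0; t < threads; ++t)
        for(std::size_t n = t; n < row.size(); n += threads)
            per_thread[t] = std::max(per_thread[t], std::fabs(row[n]));

    // stage 2: combine the lanes of each warp
    for(int w = 0; w < num_warps; ++w)
        for(int l = 0; l < lanes_per_warp; ++l)
            per_warp[w] = std::max(per_warp[w], per_thread[w * lanes_per_warp + l]);

    // stage 3: combine the warps through the shared buffer
    float absmax = 0.f;
    for(int w = 0; w < num_warps; ++w)
        absmax = std::max(absmax, per_warp[w]);
    return absmax;
}
```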
@@ -43,4 +43,5 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -37,4 +37,5 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -6,4 +6,5 @@
 #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -10,4 +10,5 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -147,7 +147,7 @@ struct Layernorm2dFwdPipelineOnePass
         if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window, ln);
+            Epilogue{}(y_window_, y_scale_window, ln, smem);
         }
         else
             Epilogue{}(y_window_, ln);
......
@@ -5,4 +5,5 @@
 #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
 #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -7,4 +7,5 @@
 #include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -16,8 +16,8 @@ namespace ck_tile {
 // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
 template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
 CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
                                            const ReduceFunc& reduce_func,
                                            bool_constant<WithBroadcast> = {})
 {
     using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
     using DstrEncode = typename Dstr::DstrEncode;
@@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
 */
 template <typename AccDistributedTensor_, typename ReduceFunc>
 CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
                                                const ReduceFunc& reduce_func)
 {
     using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
     using DstrEncode = typename Dstr::DstrEncode;
@@ -175,9 +175,9 @@ template <typename AccDistributedTensor_,
           index_t... InReduceDims,
           typename ReduceFunc>
 CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
                                       const InDistributedTensor_& in_tensor,
                                       sequence<InReduceDims...>,
                                       const ReduceFunc& reduce_func)
 {
     constexpr auto I0 = number<0>{};
     constexpr auto I1 = number<1>{};
@@ -250,9 +250,9 @@ template <typename AccDataType_,
           typename ReduceFunc,
           typename InDataType_>
 CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
                                       sequence<InReduceDims...> in_reduce_dims,
                                       const ReduceFunc& reduce_func,
                                       const InDataType_& reduce_init)
 {
     using InDataType = typename InDistributedTensor_::DataType;
     using AccDataType = remove_cvref_t<AccDataType_>;
@@ -301,7 +301,10 @@ struct BlockReduce2D
                 .get_static_tile_distribution_encoding(),
             ReduceDim{}));
-        return make_static_distributed_tensor<InDataType>(acc_dstr);
+        auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
+        // init acc_tensor
+        tile_elementwise_inout([&](auto& x_) { x_ = type_convert<InDataType>(reduce_init); }, dst_);
+        return dst_;
     }
     // return number of pixels each lane need to reduce
......
@@ -17,14 +17,24 @@ struct BlockReduce2d
     CK_TILE_DEVICE constexpr BlockReduce2d() {}
-    template <typename XDistributedTensor_, typename YDistributedTensor_, typename ReduceFunc>
+    template <typename XDistributedTensor_,
+              typename YDistributedTensor_,
+              typename ReduceFunc,
+              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
     CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                    YDistributedTensor_& y_tensor,
-                                   const ReduceFunc& reduce_func)
+                                   const ReduceFunc& reduce_func,
+                                   ReducePacksPerXDim = {})
     {
+        sweep_tile<XDistributedTensor_>(
+            [&](auto... idx_) {
+                constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
+                y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...);
+            },
+            ReducePacksPerXDim{});
+#if 0
         constexpr auto I0 = number<0>{};
         constexpr auto I1 = number<1>{};
         constexpr auto spans = XDistributedTensor_::get_distributed_spans();
         // FIXME: hard coded to reduce 2nd axis
@@ -42,6 +52,7 @@ struct BlockReduce2d
             y_tensor(y_dstr_idx) = y;
         });
+#endif
     }
     template <typename XDistributedTensor_>
@@ -63,14 +74,17 @@ struct BlockReduce2d
         return tensor;
     }
-    template <typename XDistributedTensor_, typename ReduceFunc>
+    template <typename XDistributedTensor_,
+              typename ReduceFunc,
+              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
     CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                    const ComputeDataType& reduce_init,
-                                   const ReduceFunc& reduce_func)
+                                   const ReduceFunc& reduce_func,
+                                   ReducePacksPerXDim = {})
     {
         auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
         set_tile(y_tensor, reduce_init);
-        (*this)(x_tensor, y_tensor, reduce_func);
+        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});
         return y_tensor;
     }
......
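The new ReducePacksPerXDim parameter lets the caller feed the reduce functor more than one element of the reduced dimension per call; the dynamic-quant epilogue above passes sequence<1, 2> with a three-operand functor so the row absmax maps onto v_max3_f32 (guarded by y_size_per_row % 2 == 0). Below is a plain scalar sketch of that pack-of-two idea, not the ck_tile sweep_tile machinery.

```cpp
#include <algorithm>
#include <cmath>

// absmax over n elements, consuming two per step; n is assumed even,
// mirroring the y_size_per_row % 2 == 0 guard in the epilogue.
float absmax_packed2(const float* x, int n)
{
    float acc = 0.f;
    for(int i = 0; i < n; i += 2)
        acc = std::max(acc, std::max(std::fabs(x[i]), std::fabs(x[i + 1]))); // ~ max3(acc, |x0|, |x1|)
    return acc;
}
```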
@@ -9,4 +9,5 @@
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -5,4 +5,5 @@
 #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
 #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -5,4 +5,5 @@
 #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
 #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"