Commit cd4d4629 authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents 21d12bb7 888317e6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeABXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeABXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineOnePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using ADataType = ck_tile::remove_cvref_t<typename Problem::ADataType>;
using BDataType = ck_tile::remove_cvref_t<typename Problem::BDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
else
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename AWindow,
typename BWindow,
typename GammaWindow,
typename XWindow,
typename YScaleWindow,
typename QYWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const BWindow& b_window_,
const GammaWindow& gamma_window_,
XWindow& x_window,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
{
const auto a_window =
make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
const auto b_window =
make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto a = load_tile(a_window);
const auto b = load_tile(b_window);
const auto gamma = load_tile(gamma_window);
auto x = tile_elementwise_in(
[&](const auto& a_, const auto& b_) {
return type_convert<ComputeDataType>(a_) + type_convert<ComputeDataType>(b_);
},
a,
b);
if constexpr(kSaveX)
store_tile(x_window, cast_tile<XDataType>(x));
// compute mean square, each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d(
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
// rmsnorm computation
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<ComputeDataType>(y_);
});
// compute absmax, each-thread->cross-lane->cross-warp
auto absmax = [&]() {
constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
{
return block_reduce2d(y,
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
reduce_absmax3_func,
sequence<1, 2>{});
}
else
{
return block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
}
}();
block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
// ex: yscale = absmax / 127 if int8
auto yscale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
},
absmax);
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
// quantize y to qy
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(qy, [&, yscale_ = yscale](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale_[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
});
store_tile(qy_window, qy);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template <typename ADataType_,
typename BDataType_,
typename GammaDataType_,
typename ComputeDataType_,
typename XDataType_,
typename YScaleDataType_,
typename QYDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveX_,
bool kThreePass_>
struct AddRmsnorm2dRdquantFwdPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using XDataType = remove_cvref_t<XDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using QYDataType = remove_cvref_t<QYDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kThreePass = kThreePass_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineThreePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using ADataType = ck_tile::remove_cvref_t<typename Problem::ADataType>;
using BDataType = ck_tile::remove_cvref_t<typename Problem::BDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp"; // block per row
else
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename AWindow,
typename BWindow,
typename GammaWindow,
typename XWindow,
typename YScaleWindow,
typename QYWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const BWindow& b_window_,
const GammaWindow& gamma_window_,
XWindow& x_window_,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
{
auto a_window =
make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
auto b_window =
make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
auto x_window = [&]() {
if constexpr(kSaveX)
return make_tile_window(x_window_,
Policy::template MakeABXBlockTileDistribution<Problem>());
else
return x_window_;
}();
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(a_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto a = load_tile(a_window);
const auto b = load_tile(b_window);
auto x = tile_elementwise_in(
[&](const auto& a_, const auto& b_) {
return type_convert<ComputeDataType>(a_) + type_convert<ComputeDataType>(b_);
},
a,
b);
if constexpr(kSaveX)
store_tile(x_window, cast_tile<XDataType>(x));
block_reduce2d(x, square_sum, reduce_square_sum_func);
move_tile_window(x_window, {0, Block_N});
move_tile_window(a_window, {0, Block_N});
move_tile_window(b_window, {0, Block_N});
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
if constexpr(kSaveX)
move_tile_window(x_window, {0, -Block_N});
else
{
move_tile_window(a_window, {0, -Block_N});
move_tile_window(b_window, {0, -Block_N});
}
move_tile_window(gamma_window, {stride_to_right_most_window});
using YTensorType = XTensorType;
auto absmax = block_reduce2d.template MakeYBlockTile<YTensorType>();
set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
// rmsnorm computation + absmax(threadwise reduce)
if constexpr(kSaveX)
__syncthreads();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = [&]() {
if constexpr(kSaveX)
{
return load_tile(x_window);
}
else
{
const auto a = load_tile(a_window);
const auto b = load_tile(b_window);
return tile_elementwise_in(
[&](const auto& a_, const auto& b_) {
return type_convert<ComputeDataType>(a_) +
type_convert<ComputeDataType>(b_);
},
a,
b);
}
}();
auto gamma = load_tile(gamma_window);
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms[i_idx] * gamma_;
y(idx) = type_convert<ComputeDataType>(y_);
});
constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
else
block_reduce2d(y, absmax, reduce_absmax_func);
if constexpr(kSaveX)
move_tile_window(x_window, {0, -Block_N});
else
{
move_tile_window(a_window, {0, -Block_N});
move_tile_window(b_window, {0, -Block_N});
}
move_tile_window(gamma_window, {-Block_N});
}
// compute absmax, cross-lane->cross-warp
block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
// ex: yscale = absmax / 127 if int8
auto yscale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
},
absmax);
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
// quantize y to qy
// recompute rmsnorm, try to save y in the future
if constexpr(kSaveX)
move_tile_window(x_window, {0, Block_N});
else
{
move_tile_window(a_window, {0, Block_N});
move_tile_window(b_window, {0, Block_N});
}
move_tile_window(gamma_window, {Block_N});
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = [&]() {
if constexpr(kSaveX)
{
return load_tile(x_window);
}
else
{
const auto a = load_tile(a_window);
const auto b = load_tile(b_window);
return tile_elementwise_in(
[&](const auto& a_, const auto& b_) {
return type_convert<ComputeDataType>(a_) +
type_convert<ComputeDataType>(b_);
},
a,
b);
}
}();
auto gamma = load_tile(gamma_window);
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(y, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms[i_idx] * gamma_;
auto qy_ = y_ / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
});
store_tile(qy_window, qy);
if constexpr(kSaveX)
move_tile_window(x_window, {0, Block_N});
else
{
move_tile_window(a_window, {0, Block_N});
move_tile_window(b_window, {0, Block_N});
}
move_tile_window(gamma_window, {Block_N});
move_tile_window(qy_window, {0, Block_N});
}
}
};
} // namespace ck_tile
......@@ -3,4 +3,5 @@
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_, // block size, seq<M, N>
typename WarpPerBlock_, // num warps along seq<M, N>
typename WarpTile_, // warp size, seq<M, N>
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
struct Generic2dBlockShape
{
// block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
// warp size
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
// repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
// vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{});
static constexpr index_t Vector_N = Vector_::at(number<1>{});
static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0);
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
};
} // namespace ck_tile
......@@ -3,6 +3,6 @@
#pragma once
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile {
namespace element_wise {
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
}
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
const float x_tmp = ck_tile::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
};
float scale_;
};
struct ScaleAndResetNaNToMinusInfinity
{
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = ck_tile::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"Data type is not supported by this operation!");
y = ck_tile::sqrt(x);
};
};
struct Relu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
float x_f32 = ck_tile::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
}
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u);
y = x * ck_tile::rcp(1.f + emu);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
};
struct FastGeluAsm
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp;
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
"s_nop 0 ; hazard for rcp \n"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
:);
}
template <>
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u0 = x.x * (c1 * x.x * x.x + c2);
const float emu0 = exp(u0);
y.x = x.x / (1.f + emu0);
const float u1 = x.y * (c1 * x.y * x.y + c2);
const float emu1 = exp(u1);
y.y = x.y / (1.f + emu1);
}
// this is packed verion to remove data hazard for trans
template <>
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp0, tmp1;
float y0 = x.x, y1 = x.y;
asm volatile(
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
: [v_y0] "+v"(y0),
[v_y1] "+v"(y1),
[v_c2] "+v"(c2),
// NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
// tmp variables we need to expicitly hint compiler they may read+write, to allow
// allocate different register , the side effect is c2=** may issue for every such
// inline asm block
[v_tmp0] "+v"(tmp0),
[v_tmp1] "+v"(tmp1)
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
:);
y.x = y0;
y.y = y1;
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::fp16_t(0.5) * x *
(ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
}
};
struct Sigmoid
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
};
};
struct Silu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
};
struct TanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
};
};
struct ACos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
};
};
struct Neg
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
};
};
struct ATan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
};
};
struct Sin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
};
};
struct ASinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
};
};
struct Cos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
};
};
struct ACosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
};
};
struct Tan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
};
};
struct ATanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
};
};
struct SinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
};
};
struct Ceil
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
};
};
struct Exp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
};
};
struct CosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
};
};
struct Floor
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
};
};
struct Log
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
};
};
struct ASin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asin(x);
};
};
struct Rcp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::rcp(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
std::is_same_v<Y, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
};
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck_tile::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
}
const float alpha_;
};
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
}
const float alpha_;
};
struct ConvInvscale
{
CK_TILE_HOST_DEVICE
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScale
{
CK_TILE_HOST_DEVICE
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScaleRelu
{
CK_TILE_HOST_DEVICE
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<ck_tile::fp8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
template <typename DstType, typename SrcType>
struct Cast
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
{
y = ck_tile::type_convert<DstType>(x);
};
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
......@@ -5,4 +5,6 @@
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -56,6 +56,13 @@ struct CShuffleEpilogue
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true.
return false;
}
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
......@@ -111,7 +118,9 @@ struct CShuffleEpilogue
}
}
template <typename ODramWindowTmp, typename OAccTile>
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
......@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,41 +9,65 @@ namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_, typename ODataType_, bool kPadM_, bool kPadN_>
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
{
// TODO: this is ugly
if constexpr(kPadM || kPadN)
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool UseSmoothInputScale_,
bool UseRawStore_ = true,
bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr bool UseMax3 = UseMax3_;
};
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
typename ODataType_,
typename BlockShape_,
typename Traits_>
struct DynamicQuantEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
using Traits = remove_cvref_t<Traits_>;
};
// TODO: we should put descriptor creation function into policy
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
static constexpr bool kPadM = Problem::Traits::kPadM;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
static constexpr bool UseMax3 = Problem::Traits::UseMax3;
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2d<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dSync<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
{
using S = BlockShape;
#if 0
// don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<0, 1, 1>,
sequence<0, 0, 3>>{});
#else
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
#endif
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
return reduce_crosswarp_sync.GetSmemSize();
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp,
typename XScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
const auto x_scale_window =
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
auto x_scale = load_tile(x_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto row_absmax = [&]() {
constexpr auto y_size_per_row =
OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
number<1>{});
if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
{
// fast max3+abs implementation
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
}
else
{
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
}
}();
reduce_sync(row_absmax, f_absmax);
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
// here y_scale is Acc TYpe, need convert to YScale type later
auto y_scale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
},
row_absmax);
store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
});
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// A async load to LDS, B direct to AGPR
// B matrix preshuffled in br*kr*w
// require 4 wave, occupancy=1c
// agpr useage:256
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
//
// for this gemm, 4 16x16x16 transposed layout
// input A vpgpr layout
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
// v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vpgpr layout
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
// ......................
// v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vpgpr layout
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
// ......................
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 512;
static constexpr index_t Block_K = 128;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t NumWarps = 4;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
static constexpr index_t BlockSize = 256;
static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
// template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// load from LDS to register, every wave has same layout
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KPad = KPack_; // pad between warps
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
static_assert(KPack_ == (kABKPerLane * kKIter));
constexpr auto lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
number<kAMLane>{}, // m1 p
number<Repeat_K>{}, // k0 y
number<kABKLane>{}, // k1 p
number<KPack_>{}), // k2 y-vector
make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
number<Block_K + KPad>{}, // m1
number<kABKLane * KPack_>{}, // k0
number<KPack_>{}, // k1
number<1>{}), // k2
number<KPack_>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(
make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
static constexpr auto GetGemm_AWarpEnc()
{
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
using enc_ = tile_distribution_encoding<
sequence<>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
return enc_{};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 32 * (128 + 8) * sizeof(bf16_t);
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = bf16_t;
using BDataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = fp16_t;
using BDataType = fp16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 128;
static constexpr index_t Block_K = 512;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32;
static constexpr index_t BlockSize = 256;
// static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
// index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
// [s_loop_cnt]"+s"(loop_cnt),
[s_loop_cnt]"+s"(n),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37","s59","s80",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
// index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(n),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37","s59","s80",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
the files under this folder should not be included directly!
\ No newline at end of file
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt 0 \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], "
"v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], "
"v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], "
"v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], "
"%[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], "
"%[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], "
"%[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], "
"%[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], "
"%[c13], %[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], "
"%[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], "
"%[c13], %[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], "
"%[c15]]\n"
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n" _UK_PK_CVT_(
"%[c0]", "%[c1]", "%[c0]") _UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]")
_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]") _UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]") _UK_PK_CVT_(
"%[c8]", "%[c9]", "%[c4]") _UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]")
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]") _UK_PK_CVT_(
"%[c14]",
"%[c15]",
"%[c7]") " ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] "
" \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] "
" \n"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] "
" \n"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] "
" \n"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] "
" \n"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] "
" \n"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] "
" \n"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] "
" \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] "
" \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], "
"v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], "
"v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], "
"v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], "
"v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], "
"v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], "
"v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], "
"v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], "
"v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], "
"[%[c28],%[c29],%[c30],%[c31]]\n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]") _UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]") _UK_PK_CVT_(
"%[c20]", "%[c21]", "%[c18]") _UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]")
_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]") _UK_PK_CVT_(
"%[c26]", "%[c27]", "%[c21]") _UK_PK_CVT_("%[c28]",
"%[c29]",
"%[c22]") _UK_PK_CVT_("%[c30]",
"%[c31]",
"%[c23]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_
" %[v_os_o0], %[c16], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_
" %[v_os_o1], %[c17], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_
" %[v_os_o2], %[c18], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_
" %[v_os_o3], %[c19], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], %[c20], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], %[c21], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], %[c22], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], %[c23], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment