Unverified Commit f1e53807 authored by Illia Silin, committed by GitHub

Merge branch 'develop' into ck_host_lib

parents 7450417d d9f1ead3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct BatchedTransposeHostArgs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
};
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
struct BatchedTransposeKargs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
index_t dim_stride;
};
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_z = h.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
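// e.g. width = 1000, dim_block_w = 64 -> grid_size_x = (1000 + 63) / 64 = 16
// (ceil-division, so the last block covers the padded tail of the row)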
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.batch = h.batch;
k.height = h.height;
k.width = h.width;
k.dim_stride = h.dim_stride;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
static_assert(kMPerThread == 1 && kNPerThread == 1);
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<kNPerThread>{}, // TODO thread load value
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<kMPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<kPadN, kPadM>{});
}();
// note: iM/iN above are already element offsets (blockIdx * tile size),
// so they must not be scaled by the tile size a second time here
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window);
}
};
} // namespace ck_tile
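// Host-side wiring sketch (illustrative only, not part of this header; the
// tile shapes are assumptions picked so one 8x8 warp tile maps to a 64-lane
// wave, and the concrete launch helper is left to the caller's framework):
//
//   using Problem  = ck_tile::BatchedTransposeProblem<ck_tile::fp16_t,
//                                                     ck_tile::sequence<16, 16>,  // block tile
//                                                     ck_tile::sequence<8, 8>,    // warp tile
//                                                     ck_tile::sequence<1, 1>>;   // thread tile
//   using Pipeline = ck_tile::BatchedTransposePipeline<Problem>;
//   using Kernel   = ck_tile::BatchedTransposeKernel<Pipeline>;
//
//   ck_tile::BatchedTransposeHostArgs h{p_in, p_out, batch, H, W,
//                                       /*dim_stride=*/H * W,
//                                       /*dim_block_h=*/16, /*dim_block_w=*/16};
//   const dim3 grids = Kernel::GridSize(h);
//   const auto kargs = Kernel::MakeKargs(h);
//   // launch Kernel{} with <grids, Kernel::BlockSize()> and kargs via your HIP launcher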
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
struct BatchedTransposePipeline
{
// TODO: this kernel only supports one warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t AlignmentM = Problem::AlignmentM;
static constexpr index_t AlignmentN = Problem::AlignmentN;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
template <typename InputWindow, typename OutputWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x: per-thread tile data loaded through the block-level input window
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
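// the transpose itself is encoded in the distributions: inp_win is distributed
// as (M, N) while out_win/y use the swapped (N, M) distribution from
// Policy::MakeOutputDistribution, so the plain per-thread copy below lands
// every element at its transposed position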
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct BatchedTransposePolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
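// note: MakeOutputDistribution below is this encoding with the M and N roles
// exchanged (swapped tuple order and major indices), so the same thread owns
// x(m, n) and the matching y(n, m) element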
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2, 1>,
sequence<2, 2>>{});
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
namespace ck_tile {
template <typename InputType_,
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile, // Sequence<...
bool kPadM_ = true,
bool kPadN_ = true>
struct BatchedTransposeProblem
{
using InputType = remove_cvref_t<InputType_>;
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
static constexpr index_t kBlockSize =
kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
};
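// Worked instantiation (a sketch for illustration only; these tile shapes are
// assumptions, not defaults shipped by the library): a 16x16 block tile covered
// by 2x2 waves of 8x8, one element per thread, i.e. kMThreadPerWarp =
// kNThreadPerWarp = 8 and kMWarpPerBlock = kNWarpPerBlock = 2.
using ExampleTransposeProblem = // hypothetical alias, not part of the API
BatchedTransposeProblem<fp16_t, sequence<16, 16>, sequence<8, 8>, sequence<1, 1>>;
static_assert(ExampleTransposeProblem::kBlockSize == 256, "4 waves of 64 lanes");
static_assert(ExampleTransposeProblem::AlignmentM == 8, "16B / sizeof(fp16_t)");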
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
@@ -572,6 +572,105 @@ struct FastGelu
}
};
struct FastGeluAsm
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
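// Where c1/c2 come from (derivation sketch, no new API): the tanh GELU
// approximation is y = 0.5*x*(1 + tanh(t)), t = 0.797885*x + 0.035677*x^3,
// where 0.797885 ~= sqrt(2/pi) and 0.035677 ~= 0.044715*sqrt(2/pi).
// Since 0.5*(1 + tanh(t)) == 1/(1 + exp(-2*t)), we get y = x / (1 + exp(u))
// with u = -2*t = x*(c1*x*x + c2), exactly the form computed above.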
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp;
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
"s_nop 0 ; hazard for rcp \n"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
:);
}
template <>
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u0 = x.x * (c1 * x.x * x.x + c2);
const float emu0 = exp(u0);
y.x = x.x / (1.f + emu0);
const float u1 = x.y * (c1 * x.y * x.y + c2);
const float emu1 = exp(u1);
y.y = x.y / (1.f + emu1);
}
// this is the packed version: two independent element streams hide the data
// hazards of the transcendental ops (no s_nop needed between exp/rcp)
template <>
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp0, tmp1;
float y0 = x.x, y1 = x.y;
asm volatile(
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
: [v_y0] "+v"(y0),
[v_y1] "+v"(y1),
[v_c2] "+v"(c2),
// NOTE: c2/y0/y1 are all local temporaries, so the compiler could put them
// in the same register. Marking them "+v" (read+write) explicitly hints that
// they must get distinct registers; the side effect is that the c2
// initialization may be re-issued for every such inline asm block.
[v_tmp0] "+v"(tmp0),
[v_tmp1] "+v"(tmp1)
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
:);
y.x = y0;
y.y = y1;
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
@@ -620,8 +719,83 @@ struct Silu
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
template <>
CK_TILE_HOST_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
constexpr auto one = type_convert<float>(1);
y[0] = x[0] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[0]));
y[1] = x[1] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[1]));
};
};
#if 0
// Silu: the formula has a dependency chain that does not map well to inline asm.
// We keep the code here on purpose in case people want to try it in the future.
struct SiluAsm
{
template <typename T>
CK_TILE_HOST void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
template <typename T>
CK_TILE_DEVICE void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y must not share a register going into the inline asm.
// "+v" on y and "v" on x is not enough; x/y may still be placed in the same register.
T tmp = x;
asm volatile("v_mul_f32 %[v_y], %[s_log2e], %[v_x]\n"
"v_exp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_y], %[v_y], 1.0\n"
"v_rcp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for rcp\n"
"v_mul_f32 %[v_y], %[v_x], %[v_y]\n"
: [v_y] "+v"(y), [v_x] "+v"(tmp)
: [s_log2e] "s"(log2e_neg_)
:);
};
template <>
CK_TILE_HOST void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
constexpr auto one = type_convert<float>(1);
y[0] = x[0] * (one / (one + ck_tile::exp(-x[0])));
y[1] = x[1] * (one / (one + ck_tile::exp(-x[1])));
};
template <>
CK_TILE_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y can't be same register before inline asm
// float tmp0 = x[0], tmp1 = x[1];
asm volatile("v_mul_f32 %[v_y0], %[s_log2e], %[v_x0]\n"
"v_mul_f32 %[v_y1], %[s_log2e], %[v_x1]\n"
"v_exp_f32 %[v_y0], %[v_y0]\n"
"v_exp_f32 %[v_y1], %[v_y1]\n"
"v_add_f32 %[v_y0], %[v_y0], 1.0\n"
"v_add_f32 %[v_y1], %[v_y1], 1.0\n"
"v_rcp_f32 %[v_y0], %[v_y0]\n"
"v_rcp_f32 %[v_y1], %[v_y1]\n"
"v_mul_f32 %[v_y0], %[v_x0], %[v_y0]\n"
"v_mul_f32 %[v_y1], %[v_x1], %[v_y1]\n"
: [v_y0] "+v"(y[0]), [v_y1] "+v"(y[1]), [v_x0] "+v"(x[0]), [v_x1] "+v"(x[1])
: [s_log2e] "s"(log2e_neg_)
:);
};
};
#endif
struct TanH
{
template <typename T>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
// this epilogue aims at storing a matrix whose layout differs from the shared
// memory layout out to global memory.
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t kMWave_,
index_t kNWave_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kMWave = kMWave_;
static constexpr index_t kNWave = kNWave_;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kMWave = Problem::kMWave;
static constexpr index_t kNWave = Problem::kNWave;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for the output C tensor depends on multiple factors
* like its data layout and warp gemm C transposition. In general it is
* the number of consecutive elements in the contiguous C dimension held by a
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
}
template <typename ODramWindow,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
{
const index_t iMWarp = get_warp_id() / kNWave;
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC<ODataType>(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
}
});
}
};
} // namespace ck_tile
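// Quick numeric check (illustrative values, not defaults): with ODataType =
// fp16_t (2 bytes), GetVectorSizeC() = 16 / 2 = 8 elements per store; with
// kMWave = kNWave = 2 and kMPerXdl = kNPerXdl = 16, GetSmemSize() =
// 2 * 2 * 16 * 16 * 2 = 2048 bytes of LDS for the shuffle staging buffer.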
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
@@ -37,20 +59,109 @@ struct Default2DEpilogue
// TODO: this function assumes the store-out vector size is the same as the
// OAccTile last dimension size. How do we fix this?
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence();
}
else
{
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
}
};
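// Usage note (sketch): out_memory_data_op defaults to set; pass update when
// the epilogue must accumulate into an existing C (e.g. split-K partials):
//   epilogue.template operator()<ODramWindow, OAccTile,
//                                memory_operation_enum::update>(window, acc_tile);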
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -24,19 +24,19 @@ struct DynamicQuantEpilogueTraits
// this epilogue just stores out an M*N matrix, row major
template <typename AccDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename ODataType_,
typename BlockShape_,
typename Traits_>
struct DynamicQuantEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consume generic 2d shape
using Traits = remove_cvref_t<Traits_>;
};
// TODO: we should put descriptor creation function into policy
@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue
#if 0
// don't remove this
// Note that if we purposely set the encoding like this, it will fail to compile
// TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
@@ -105,34 +105,18 @@ struct DynamicQuantEpilogue
return reduce_crosswarp_sync.GetSmemSize();
}
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
auto o_acc_tmp = o_acc_tile;
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto row_absmax = [&]() {
@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
}
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
// Smooth Dynamic Quant
template <typename ODramWindowTmp,
typename SmoothScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
const auto sm_scale_window =
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
auto sm_scale = load_tile(sm_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
}
// Dynamic Quant
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
}
};
} // namespace ck_tile
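// Reference math for the row-wise dynamic quant computed above (scalar sketch;
// the 127 divisor assumes an int8 ODataType and lives in the elided region):
//   y_scale[m] = max_n |acc[m][n]| / 127
//   out[m][n]  = acc[m][n] / y_scale[m]
// The smooth-quant overload first scales each column: acc[m][n] *= sm_scale[n].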
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// A: async load to LDS, B: direct to AGPR
// B matrix preshuffled in br*kr*w
// requires 4 waves, occupancy=1
// agpr usage: 256
// vgpr usage: 64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
//
// for this gemm, 4 16x16x16 transposed layouts
// input A vgpr layout
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
// v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vgpr layout
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
// ......................
// v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vgpr layout
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
// ......................
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 512;
static constexpr index_t Block_K = 128;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t NumWarps = 4;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
static constexpr index_t BlockSize = 256;
static constexpr index_t SubKPacks = 2; // this is used to guarantee every thread can do dwordx4
// TODO: note Nr/Kr/W need consider SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
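// Worked trace of the first branch above (illustrative numbers): Block_K = 128
// and KVector = 2 give LanesPerK = 64 == warpSize, hence wavesPerK = 1,
// wavesPerM = NumWarps = 4 and NumIssues = Block_M / 4 = 8; each of the 4 waves
// issues 8 async copies of 64 lanes x 1 dword, with a KPad = 8-element pad per
// wave row (presumably to avoid bank conflicts on the later LDS reads).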
// template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// load from LDS to register, every wave has same layout
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KPad = KPack_; // pad between warps
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
static_assert(KPack_ == (kABKPerLane * kKIter));
constexpr auto lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
number<kAMLane>{}, // m1 p
number<Repeat_K>{}, // k0 y
number<kABKLane>{}, // k1 p
number<KPack_>{}), // k2 y-vector
make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
number<Block_K + KPad>{}, // m1
number<kABKLane * KPack_>{}, // k0
number<KPack_>{}, // k1
number<1>{}), // k2
number<KPack_>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(
make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
static constexpr auto GetGemm_AWarpEnc()
{
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
using enc_ = tile_distribution_encoding<
sequence<>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
return enc_{};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// return 32 * (128 + 8) * sizeof(bf16_t);
return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
}
};
// clang-format off
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
[s_loop_cnt]"+s"(loop_cnt), \
[v_acc_0]"+v"(v_acc[0]), \
[v_acc_1]"+v"(v_acc[1]), \
[v_acc_2]"+v"(v_acc[2]), \
[v_acc_3]"+v"(v_acc[3]), \
[v_acc_4]"+v"(v_acc[4]), \
[v_acc_5]"+v"(v_acc[5]), \
[v_acc_6]"+v"(v_acc[6]), \
[v_acc_7]"+v"(v_acc[7]), \
[v_acc_8]"+v"(v_acc[8]), \
[v_acc_9]"+v"(v_acc[9]), \
[v_acc_10]"+v"(v_acc[10]), \
[v_acc_11]"+v"(v_acc[11]), \
[v_acc_12]"+v"(v_acc[12]), \
[v_acc_13]"+v"(v_acc[13]), \
[v_acc_14]"+v"(v_acc[14]), \
[v_acc_15]"+v"(v_acc[15]), \
[s_mem_]"+r"(smem)
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
[s_loop_cnt]"+s"(loop_cnt), \
[v_acc_0]"+v"(v_acc[0]), \
[v_acc_1]"+v"(v_acc[1]), \
[v_acc_2]"+v"(v_acc[2]), \
[v_acc_3]"+v"(v_acc[3]), \
[v_acc_4]"+v"(v_acc[4]), \
[v_acc_5]"+v"(v_acc[5]), \
[v_acc_6]"+v"(v_acc[6]), \
[v_acc_7]"+v"(v_acc[7]), \
[v_acc_8]"+v"(v_acc[8]), \
[v_acc_9]"+v"(v_acc[9]), \
[v_acc_10]"+v"(v_acc[10]), \
[v_acc_11]"+v"(v_acc[11]), \
[v_acc_12]"+v"(v_acc[12]), \
[v_acc_13]"+v"(v_acc[13]), \
[v_acc_14]"+v"(v_acc[14]), \
[v_acc_15]"+v"(v_acc[15]), \
[v_acc_16]"+v"(v_acc[16]), \
[v_acc_17]"+v"(v_acc[17]), \
[v_acc_18]"+v"(v_acc[18]), \
[v_acc_19]"+v"(v_acc[19]), \
[v_acc_20]"+v"(v_acc[20]), \
[v_acc_21]"+v"(v_acc[21]), \
[v_acc_22]"+v"(v_acc[22]), \
[v_acc_23]"+v"(v_acc[23]), \
[v_acc_24]"+v"(v_acc[24]), \
[v_acc_25]"+v"(v_acc[25]), \
[v_acc_26]"+v"(v_acc[26]), \
[v_acc_27]"+v"(v_acc[27]), \
[v_acc_28]"+v"(v_acc[28]), \
[v_acc_29]"+v"(v_acc[29]), \
[v_acc_30]"+v"(v_acc[30]), \
[v_acc_31]"+v"(v_acc[31]), \
[s_mem_]"+r"(smem)
#define _EXPAND_ASM_ARGS_IN \
[s_res_a0]"s"(res_a[0]), \
[s_res_a1]"s"(res_a[1]), \
[s_res_a2]"s"(res_a[2]), \
[s_res_a3]"s"(res_a[3]), \
[s_res_b0]"s"(res_b[0]), \
[s_res_b1]"s"(res_b[1]), \
[s_res_b2]"s"(res_b[2]), \
[s_res_b3]"s"(res_b[3]), \
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
\
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
\
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
[s_m0_init]"s"(m0_init_value), \
[s_size_per_issue]"s"(size_per_issue), \
[smem_sz]"n"(smem_buf_size), \
[sld_os_0]"n"(sld_os[number<0>{}].value), \
[sld_os_1]"n"(sld_os[number<1>{}].value), \
[sld_os_2]"n"(sld_os[number<2>{}].value), \
[sld_os_3]"n"(sld_os[number<3>{}].value), \
[sld_os_4]"n"(sld_os[number<4>{}].value), \
[sld_os_5]"n"(sld_os[number<5>{}].value), \
[sld_os_6]"n"(sld_os[number<6>{}].value), \
[sld_os_7]"n"(sld_os[number<7>{}].value), \
[s_tile_os_a]"s"(tile_offset_a_bytes), \
[s_tile_os_b]"s"(tile_offset_b_bytes)
#define _EXPAND_ASM_ARGS_CLOBBER \
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
"a252", "a253", "a254", "a255", \
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
"s86", \
"v64", "v65", "v66", "v67", "v68", "v69", \
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
"v124", "v125", "v126", "v127"
// clang-format on
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = bf16_t;
using BDataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// Is2B: originally the B matrix has 2 prefetch buffers. If this is set to true,
// one A matrix can serve 2 B matrices, B0/B1, each with the same tile size
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b,
bool_constant<Is2B> = {}) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
if constexpr(Is2B)
{
// this is the acc thread buffer
fp32x4_t v_acc[32]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#define CK_TILE_FLATMM_UK_2B 1
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
: _EXPAND_ASM_ARGS_IN,
[s_res_b4]"s"(res_b[4]),
[s_res_b5]"s"(res_b[5]),
[s_res_b6]"s"(res_b[6]),
[s_res_b7]"s"(res_b[7])
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
for(auto i = 0; i < 16; i++)
{
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
for(auto i = 0; i < 16; i++)
{
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
}
return c;
}
else
{
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
: _EXPAND_ASM_ARGS_IN
: _EXPAND_ASM_ARGS_CLOBBER
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = fp16_t;
using BDataType = fp16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b, // for each tile, the offset to move for each unroll
bool_constant<Is2B> = {})
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
if constexpr(Is2B)
{
// this is the acc thread buffer
fp32x4_t v_acc[32]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#define CK_TILE_FLATMM_UK_2B 1
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
: _EXPAND_ASM_ARGS_IN,
[s_res_b4]"s"(res_b[4]),
[s_res_b5]"s"(res_b[5]),
[s_res_b6]"s"(res_b[6]),
[s_res_b7]"s"(res_b[7])
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
for(auto i = 0; i < 16; i++)
{
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
for(auto i = 0; i < 16; i++)
{
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
}
return c;
}
else
{
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
: _EXPAND_ASM_ARGS_IN
: _EXPAND_ASM_ARGS_CLOBBER
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
}
};
#undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
#undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
#undef _EXPAND_ASM_ARGS_IN
#undef _EXPAND_ASM_ARGS_CLOBBER
} // namespace ck_tile