Unverified Commit 171b9030 authored by Mirza Halilčević's avatar Mirza Halilčević Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 417f805f da0c21f6
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -19,27 +19,27 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -19,27 +19,27 @@ struct GemmPipelineAGmemBGmemCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK; static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t AlignmentA = Problem::AlignmentA; static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t AlignmentB = Problem::AlignmentB; static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t AlignmentC = Problem::AlignmentC; static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
using LayoutA = remove_cvref_t<typename Problem::LayoutA>; static constexpr bool kPadM = Problem::kPadM;
using LayoutB = remove_cvref_t<typename Problem::LayoutB>; static constexpr bool kPadN = Problem::kPadN;
using LayoutC = remove_cvref_t<typename Problem::LayoutC>; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return integer_divide_ceil(
sizeof(ADataType) * sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 16) *
...@@ -48,7 +48,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -48,7 +48,7 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
} }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window = make_tile_window(
make_tile_window(a_lds_block, a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load // B DRAM tile window for load
auto b_copy_dram_window = auto b_copy_dram_window =
...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window = make_tile_window(
make_tile_window(b_lds_block, b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM // A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window( auto a_lds_gemm_window = make_tile_window(
...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
store_tile(a_copy_lds_window, a_block_tile_tmp); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0 // LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
} }
index_t iCounter = num_loop - 1; index_t iCounter = num_loop - 1;
...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, a_block_tile_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1 // LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--; iCounter--;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead // Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
#if 0 #if 0
// 2d // 2d
template <typename Problem> template <typename Problem>
...@@ -71,8 +72,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -71,8 +72,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{ {
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...@@ -93,7 +92,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -93,7 +92,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(); MakeALdsBlockDescriptor<Problem>().get_element_space_size();
...@@ -101,7 +100,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -101,7 +100,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
...@@ -109,7 +108,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -109,7 +108,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
constexpr index_t smem_size_a = GetSmemSizeA<Problem>(); constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>(); constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
...@@ -118,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -118,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return smem_size; return smem_size;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1 #elif 1
// fake XOR // fake XOR
template <typename Problem> template <typename Problem>
...@@ -194,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -194,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
constexpr index_t M2 = get_warp_size() / K0; {
#if 1 // coalesce reading for each blocks constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = MPerBlock / M1;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(total_pixels % M1 == 0);
constexpr index_t M0 = kMPerBlock / (M2 * M1); constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_static_tile_distribution( static_assert(KPack % K3 == 0);
tile_distribution_encoding<sequence<1>, constexpr index_t K2 = KPack / K3;
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, if constexpr(get_warp_size() % (K2 * M0))
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K1 = get_warp_size() / (K2 * M0);
sequence<1, 2>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<0, 1>>{}); static_assert(KPerBlock == K0 * K1 * K2 * K3);
#else // coalesce reading for each warps return make_static_tile_distribution(
constexpr index_t M0 = kBlockSize / get_warp_size(); tile_distribution_encoding<sequence<1>,
constexpr index_t M1 = kMPerBlock / (M2 * M0); tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
return make_static_tile_distribution( tuple<sequence<0>, sequence<1, 0, 2>>,
tile_distribution_encoding<sequence<1>, sequence<2, 1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, sequence<3, 1>>{});
tuple<sequence<1>, sequence<1, 2>>, }
tuple<sequence<0>, sequence<2, 0>>, else
sequence<1, 2>, {
sequence<1, 1>>{}); constexpr index_t K1 = (K2 * M0) / get_warp_size();
#endif constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType); constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
#if 1 // coalesce reading for each blocks static_assert(total_pixels % N1 == 0);
constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t K3 = total_pixels / N1;
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(kKPack % K3 == 0);
constexpr index_t N0 = kNPerBlock / (N2 * N1); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
return make_static_tile_distribution( if constexpr(warp_size % (K2 * N0) == 0)
tile_distribution_encoding<sequence<1>, {
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, constexpr index_t K1 = warp_size / (K2 * N0);
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K0 = kBlockSize / warp_size;
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, return make_static_tile_distribution(
sequence<0, 1>>{}); tile_distribution_encoding<sequence<1>,
#else // coalesce reading for each warps tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
constexpr index_t N0 = kBlockSize / get_warp_size(); tuple<sequence<2>, sequence<2, 1, 2>>,
constexpr index_t N1 = kNPerBlock / (N2 * N0); tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
return make_static_tile_distribution( sequence<1, 3>>{});
tile_distribution_encoding<sequence<1>, }
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, else
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<0>, sequence<2, 0>>, constexpr index_t K1 = (K2 * N0) / get_warp_size();
sequence<1, 2>, constexpr index_t K2_m = K2 / K1;
sequence<1, 1>>{}); constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
#endif static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -25,9 +25,9 @@ struct GemmPipelineAGmemBGmemCRegV2 ...@@ -25,9 +25,9 @@ struct GemmPipelineAGmemBGmemCRegV2
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK; static constexpr index_t kKPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return integer_divide_ceil(
sizeof(ADataType) * sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 16) *
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#define VectorLoadSize 16
namespace ck_tile { namespace ck_tile {
...@@ -14,26 +12,141 @@ template <typename ADataType_, ...@@ -14,26 +12,141 @@ template <typename ADataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename TileGemmTraits_> typename TileGemmTraits_>
struct GemmPipelineProblem struct GemmPipelineProblemBase
{ {
using ADataType = remove_cvref_t<ADataType_>; using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>; using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
static constexpr bool kPadA = GemmTraits::kPadA; using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
static constexpr bool kPadB = GemmTraits::kPadB; using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
static constexpr bool kPadC = GemmTraits::kPadC;
static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: VectorLoadSize / sizeof(BDataType);
}
else
{
return VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
using LayoutA = remove_cvref_t<typename GemmTraits::LayoutA>; static constexpr index_t VectorSizeC = []() {
using LayoutB = remove_cvref_t<typename GemmTraits::LayoutB>; if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
using LayoutC = remove_cvref_t<typename GemmTraits::LayoutC>; {
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
};
// Alias for GemmPipelineProblem
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); template <typename ADataType_,
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType); typename BDataType_,
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType); typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
TileGemmTraits_>
{
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -9,12 +9,8 @@ ...@@ -9,12 +9,8 @@
namespace ck_tile { namespace ck_tile {
// UniversalGemm Policy // UniversalGemm Policy
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
struct UniversalGemmPipelineAgBgCrPolicy struct UniversalGemmPipelineAgBgCrPolicy
{ {
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
...@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
static constexpr bool TransposeC = true; static constexpr bool TransposeC = true;
template <typename Problem, typename DataType, index_t MNPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
{
return (16 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
{
return (8 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 4)
{
return (4 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 2)
{
return (2 / sizeof(DataType));
}
else
{
return 1;
}
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
constexpr index_t K0 = KPerBlock / K1;
constexpr auto DataTypeSize = sizeof(ADataType);
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value) constexpr auto MLdsLayer =
{ (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
: 32 * 4 / KPerBlock / sizeof(ADataType); make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( number<MPerBlock / MLdsLayer>{},
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1), number<KPack>{}),
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1)); make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( number<1>{});
a_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{}, constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
number<K0 * MLdsLayer>{})), a_lds_block_desc_0,
make_pass_through_transform(K1)), make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
make_tuple(sequence<1, 0>{}, sequence<2>{}), number<KPerBlock / KPack * MLdsLayer>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{})); make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( make_tuple(sequence<1, 0>{}, sequence<2>{}));
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})), constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}), a_lds_block_desc_permuted,
make_pass_through_transform(K1)), make_tuple(make_unmerge_transform(
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
a_lds_block_desc_ak0_kMLdsLayer_m_ak1, make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
make_merge_transform_v3_division_mod( constexpr auto a_lds_block_desc = transform_tensor_descriptor(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))), a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), make_tuple(make_merge_transform_v3_division_mod(
make_tuple(sequence<1>{}, sequence<0>{})); make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
return a_lds_block_desc_m_k; make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
} make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
else // ColumnMajor A make_tuple(sequence<0>{}, sequence<1>{}));
{
// kfold and mpair dimension is not always required. return a_lds_block_desc;
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kM;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=kN0
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
? M0
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{},
number<mpair>{},
K1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_m_k;
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1; constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value) (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
{
// NLdsLayer * K0 as logical Bank constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 make_tuple(number<KPerBlock / KPack * NLdsLayer>{},
? 1 number<NPerBlock / NLdsLayer>{},
: 32 * 4 / KPerBlock / sizeof(BDataType); number<KPack>{}),
; make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( number<KPack>{},
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1), number<1>{});
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0,
b_lds_block_desc, make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{}, number<KPerBlock / KPack * NLdsLayer>{})),
number<K0 * NLdsLayer>{})), make_pass_through_transform(number<KPack>{})),
make_pass_through_transform(K1)), make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{}));
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( b_lds_block_desc_permuted,
b_lds_block_desc_permuted, make_tuple(make_unmerge_transform(
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})), make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}), make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(K1)), make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1, b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), make_tuple(make_merge_transform_v3_division_mod(
make_merge_transform_v3_division_mod( make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))), make_merge_transform_v3_division_mod(
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0>{})); make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k; return b_lds_block_desc;
}
else // RowMajor B
{
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kN;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=kN0
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
? N0
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
K1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_n_k;
}
} }
template <typename Problem> template <typename Problem>
...@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType, using ADataType = remove_cvref_t<typename Problem::ADataType>;
typename Problem::BDataType, using ALayout = remove_cvref_t<typename Problem::ALayout>;
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
constexpr index_t K0 = KPerBlock / K1; {
constexpr index_t M2 = get_warp_size() / K0; constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t M1 = BlockSize / get_warp_size(); constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(total_pixels % M1 == 0);
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t K3 = total_pixels / M1;
constexpr index_t M0 = MPerBlock / (M2 * M1); constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
static_assert(KPack % K3 == 0);
return make_static_tile_distribution( constexpr index_t K2 = KPack / K3;
tile_distribution_encoding<sequence<1>, if constexpr(get_warp_size() % (K2 * M0) == 0)
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, {
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K1 = get_warp_size() / (K2 * M0);
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<1, 2>, static_assert(KPerBlock == K0 * K1 * K2 * K3);
sequence<0, 1>>{}); return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType, using BDataType = remove_cvref_t<typename Problem::BDataType>;
typename Problem::BDataType, using BLayout = remove_cvref_t<typename Problem::BLayout>;
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0), constexpr index_t BlockSize = Problem::kBlockSize;
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2), constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
TransposeC>; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1; constexpr index_t N0 = NPerBlock / N1;
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t N1 = BlockSize / get_warp_size(); constexpr index_t K3 = total_pixels / N1;
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(kKPack % K3 == 0);
constexpr index_t N0 = NPerBlock / (N2 * N1); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
return make_static_tile_distribution( if constexpr(warp_size % (K2 * N0) == 0)
tile_distribution_encoding<sequence<1>, {
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, constexpr index_t K1 = warp_size / (K2 * N0);
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K0 = BlockSize / warp_size;
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, return make_static_tile_distribution(
sequence<0, 1>>{}); tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,21 +7,23 @@ ...@@ -7,21 +7,23 @@
namespace ck_tile { namespace ck_tile {
template <bool kPadA_, template <bool kPadM_,
bool kPadB_, bool kPadN_,
bool kPadC_, bool kPadK_,
typename LayoutA_, typename ALayout_,
typename LayoutB_, typename BLayout_,
typename LayoutC_> typename CLayout_>
struct TileGemmTraits struct TileGemmTraits
{ {
static constexpr bool kPadA = kPadA_; static constexpr bool kPadM = kPadM_;
static constexpr bool kPadB = kPadB_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadK = kPadK_;
using LayoutA = LayoutA_; static constexpr int _VectorSize = 16;
using LayoutB = LayoutB_;
using LayoutC = LayoutC_; using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
#if defined(__gfx9__) #if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
#if defined(__gfx9__) #if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
0); 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
0); 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher; ...@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off // clang-format off
// fp16 // fp16
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16 // bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8 // fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on // clang-format on
} // namespace impl } // namespace impl
......
...@@ -6,4 +6,5 @@ ...@@ -6,4 +6,5 @@
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
#pragma once #pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -5,447 +5,384 @@ ...@@ -5,447 +5,384 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
namespace ck_tile { namespace ck_tile {
// host side args
struct Layernorm2dFwdHostArgs
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
float epsilon;
index_t m;
index_t n;
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
};
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
template <typename Problem_> template <typename Pipeline_, typename Epilogue_>
struct Layernorm2dFwd struct Layernorm2dFwd
{ {
using Problem = ck_tile::remove_cvref_t<Problem_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using Problem = typename Pipeline::Problem;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>; using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
// for simplicity, shortcut input/output type is same as X
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; using XResidualDataType = XDataType;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; using YResidualDataType = XDataType;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
struct Kargs struct Kargs
{ {
const void* p_x; const void* p_x; // [m ,n], input, fp16/bf16
const void* p_gamma; const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_beta; const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_y; void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used
void* p_mean; void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
void* p_invStd;
float epsilon; float epsilon;
ck_tile::index_t M; index_t m;
ck_tile::index_t N; index_t n;
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
}; };
using Hargs = Layernorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x, CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_mean,
void* p_invStd,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N)
{
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{ {
using S = typename Problem::BlockShape; return Kargs{hargs.p_x,
hargs.p_x_residual,
return make_static_tile_distribution( hargs.p_x_scale,
tile_distribution_encoding< hargs.p_gamma,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>, hargs.p_beta,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>, hargs.p_y,
tuple<sequence<0, 1>, sequence<0, 1>>, hargs.p_y_residual,
tuple<sequence<0, 0>, sequence<1, 1>>, hargs.p_y_scale,
sequence<1>, hargs.p_mean,
sequence<2>>{}); hargs.p_invStd,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.xr_stride,
hargs.y_stride,
hargs.yr_stride};
} }
CK_TILE_DEVICE static int GetWelfordMaxCount(int N) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{ {
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; return dim3(integer_divide_ceil(hargs.m, Block_M));
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
if(n_per_block_tail_loop > 0)
{
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count;
} }
template <typename DistributedTensor> CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon)
{
// TODO: Investigate fast inverse square root algorithm with epsilon
constexpr auto spans = DistributedTensor::get_distributed_spans();
DistributedTensor out_dstr_tensor;
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { // clang-format off
constexpr auto i_idx = make_tuple(idx0); template <typename T> struct t2s;
out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) / template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon); template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
}); template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
return out_dstr_tensor; // in byte
} CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
template <typename XBlockWindow, CK_TILE_HOST static std::string GetName()
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{ {
// TODO - Optimize tail loop to reduce move_tile_window() #define _SS_ std::string
index_t num_n_tile_iteration = #define _TS_ std::to_string
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); // clang-format off
using S_ = typename Problem::BlockShape;
int welford_max_count = GetWelfordMaxCount(N); auto surfix = [&] () {
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; std::string n;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
using XTensorType = decltype(load_tile(x_block_window)); if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
auto mean_compute_block_tensor = if (kPadN) n += "_pn";
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>(); if (kSaveMeanInvStd) n += "_mv";
auto var_compute_block_tensor = // if (kTwoPass) n += "_2p";
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>(); return n; }();
clear_tile(mean_compute_block_tensor); auto prec_str = [&] () {
clear_tile(var_compute_block_tensor); std::string base_str = _SS_(t2s<XDataType>::name);
if (!std::is_same_v<XDataType, YDataType>) {
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
{ }
const auto x_block_tensor = load_tile(x_block_window); if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
move_tile_window(x_block_window, {0, kNPerBlock}); }
} if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
// TODO: support cross warp Welford }
WarpMergeWelford<ComputeDataType, true>{}( return base_str;
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); }();
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window});
// Normalization
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {-kNPerBlock});
move_tile_window(beta_block_window, {-kNPerBlock});
move_tile_window(y_block_window, {0, -kNPerBlock});
}
}
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor); return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
// clang-format on
#undef _SS_
#undef _TS_
} }
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
const auto x_m_n = [&]() { const auto iM = get_block_id() * Block_M;
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.M, kargs.N), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.N, 1), make_tuple(kargs.x_stride, 1),
number<kNPerThread>{}, number<Vector_N>{},
number<1>{}); number<1>{});
return pad_tensor_view(x_dram_naive, // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), // check the max count dynamically
sequence<kPadM, kPadN>{}); const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.xr_stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
// will check the max count dynamically
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}(); }();
const auto gamma_n = [&]() { const auto gamma_window = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N), make_tuple(kargs.n),
make_tuple(1), make_tuple(1),
number<kNPerThread>{}, number<Vector_N>{},
number<1>{}); number<1>{});
return pad_tensor_view( const auto tmp2_ =
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{}); pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}(); }();
const auto beta_n = [&]() { const auto beta_window = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta), static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(kargs.N), make_tuple(kargs.n),
make_tuple(1), make_tuple(1),
number<kNPerThread>{}, number<Vector_N>{},
number<1>{}); number<1>{});
return pad_tensor_view( const auto tmp2_ =
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{}); pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
}(); }();
const auto iM = get_block_id() * kMPerBlock; auto y_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y), static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.N, 1), make_tuple(kargs.y_stride, 1),
number<kNPerThread>{}, number<Vector_N>{},
number<1>{}); number<1>{});
return pad_tensor_view(y_dram_naive, auto tmp2_ = pad_tensor_view(
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
sequence<kPadM, kPadN>{}); return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
auto y_block_window = make_tile_window( auto y_residual_window = [&]() {
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}); if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
constexpr auto betaDstr = gammaDstr; static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
auto gamma_block_window = make_tuple(kargs.yr_stride, 1),
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr); number<Vector_N>{},
number<1>{});
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr); auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto mean_block_window = [&]() { auto mean_window = [&]() {
if constexpr(kSaveMean) if constexpr(kSaveMean)
{ {
const auto mean_m = [&]() { const auto mean_m = [&]() {
const auto mean_dram_naive = const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>( make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean), static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M), make_tuple(kargs.m),
number<1>{}); number<1>{});
return pad_tensor_view( return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{}); mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}(); }();
return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
} }
else else
return make_null_tile_window(make_tuple(number<kMPerBlock>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto inv_std_block_window = [&]() { auto inv_std_window = [&]() {
if constexpr(kSaveInvStd) if constexpr(kSaveInvStd)
{ {
const auto inv_std_m = [&]() { const auto inv_std_m = [&]() {
const auto inv_std_dram_naive = const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>( make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd), static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M), make_tuple(kargs.m),
number<1>{}); number<1>{});
return pad_tensor_view( return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{}); inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}(); }();
return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM}); auto x_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
return make_null_tile_window(make_tuple(number<Block_N>{}));
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
} }
else else
return make_null_tile_window(make_tuple(number<kMPerBlock>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
if(kargs.N <= kNPerBlock) __shared__ char smem[GetSmemSize()];
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window, Pipeline{}(x_window,
beta_block_window, x_residual_window,
y_block_window, gamma_window,
mean_block_window, beta_window,
inv_std_block_window, y_window,
static_cast<const ComputeDataType>(kargs.epsilon), y_residual_window,
kargs.N); mean_window,
else inv_std_window,
TwoPassLayernorm2dFwd(x_block_window, x_scale_window,
gamma_block_window, y_scale_window,
beta_block_window, static_cast<const ComputeDataType>(kargs.epsilon),
y_block_window, kargs.n,
mean_block_window, smem,
inv_std_block_window, Epilogue{});
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
} }
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
namespace ck_tile {
struct Layernorm2dFwdPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
return BlockWelford<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
return BlockWelfordSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
return BlockWelfordCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
using block_welford = BlockWelford<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
return GetBlockWelfordCrossWarpSync<Problem>()
.template GetSmemSize<mean_var_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdPipelineOnePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
else
return "wpr"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
// compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(acc, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
{
return type_convert<ComputeDataType>(1.0f) *
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
}
else
{
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
}
},
var);
if constexpr(kSaveMean)
store_tile(mean_window, cast_tile<MeanDataType>(mean));
if constexpr(kSaveInvStd)
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
});
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,21 +14,27 @@ template <typename XDataType_, ...@@ -14,21 +14,27 @@ template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
bool kPadM_, typename Traits_>
bool kPadN_> struct Layernorm2dFwdPipelineProblem
struct BlockLayernorm2dFwdProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using XScaleDataType = remove_cvref_t<XScaleDataType_>;
static constexpr bool kPadM = kPadM_; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
static constexpr bool kPadN = kPadN_; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
using Traits = remove_cvref_t<Traits_>;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdPipelineTwoPass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_2p"; // block per row
else
return "wpr_2p"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
// total number of count assume current iter have no pad(only last iter has pad)
constexpr index_t count_per_iter =
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
int cur_count = 0;
int max_count =
(num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_welford(acc, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
{
return type_convert<ComputeDataType>(1.0f) *
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
}
else
{
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
}
},
var);
if constexpr(kSaveMean)
store_tile(mean_window, cast_tile<MeanDataType>(mean));
if constexpr(kSaveInvStd)
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
// layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
});
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
Epilogue{}(y_window, ln);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Layernorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before layernorm and store result to global
PRE_ADD_STORE = 1,
// fused add before layernorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Layernorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ThreadTile, // Sequence<...
typename WarpTile, // Sequence<...
typename BlockTile> // Sequence<...
struct TileLayernorm2dShape
{
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
// TODO - kNNumWarps can only be 1 if we don't support cross warp welford
static_assert(kNWarpPerBlock == 1);
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
namespace ck_tile {
/* independent host side argument, no template
*/
struct GenericPermuteHostArgs
{
static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
const void* p_src;
void* p_dst;
index_t rank;
index_t shape[kMaxRanks]; // input shape
index_t perm[kMaxRanks]; // permute index
};
/*
simulate torch.permute:
x_ = x_.view(x.shape[0],
x.shape[1]//16, 16,
x.shape[2]//32, 4, 8)
x_ = x_.permute(0,1,3,4,2,5)
x_ = x_.contiguous()
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
dim of permutation, with a single kernel
*/
template <typename Problem_>
struct GenericPermute
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using DataType = remove_cvref_t<typename Problem::DataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMaxRanks = Problem::kMaxRanks;
static constexpr bool KeepLastDim = Problem::KeepLastDim;
struct __attribute__((packed)) Kargs
{
const void* p_src;
void* p_dst;
// index_t rank;
index_t num_elements;
index_t perm_length[kMaxRanks]; // tensor length after permutation
index_t perm_stride[kMaxRanks]; // tensor stride after permutation
};
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h)
{
index_t n = 1;
for(auto i = 0; i < h.rank; i++)
{
n *= h.shape[i];
}
return n;
}
CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h)
{
Kargs a;
a.p_src = h.p_src;
a.p_dst = h.p_dst;
// assert rank <= kMaxRanks
index_t i = 0;
index_t perm[kMaxRanks];
index_t x_shape[kMaxRanks];
index_t x_stride[kMaxRanks];
// index_t perm_length[kMaxRanks];
for(; i < h.rank; i++)
{
x_shape[i] = h.shape[i];
perm[i] = h.perm[i];
}
for(; i < kMaxRanks; i++)
{
x_shape[i] = 1;
perm[i] = i; // will index to len = 1
}
index_t stride = 1;
for(index_t j = kMaxRanks - 1; j >= 0; j--)
{
x_stride[j] = stride;
stride *= x_shape[j];
}
for(index_t j = 0; j < kMaxRanks; j++)
{
a.perm_length[j] = x_shape[perm[j]];
a.perm_stride[j] = x_stride[perm[j]];
}
a.num_elements = TotalElements(h);
return a;
}
CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h)
{
auto total = TotalElements(h);
auto grids = dim3((total + BlockSize() - 1) / BlockSize());
// printf("### total:%d, grids:%dx%dx%d\n", total, );
return grids;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t id = blockIdx.x * BlockSize() + threadIdx.x;
if(id >= kargs.num_elements)
return;
const auto perm_length =
generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
const auto perm_stride =
generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
p_src, perm_length, perm_stride, number<1>{}, number<1>{});
const auto src_view = transform_tensor_view(
src_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst, perm_length, number<1>{});
auto dst_view = transform_tensor_view(
dst_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
// TODO: hard code to vector 1
using vector_t = thread_buffer<DataType, 1>;
const auto src_coord =
make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
const auto dst_coord =
make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment