Commit d3689b06 authored by Adam Osewski's avatar Adam Osewski
Browse files

Do not use ck_tile:: within ck_tile namespace.

parent 6cdce153
...@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1 ...@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>, BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBGmemCRegV1DefaultPolicy>; BlockGemmARegBGmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return sizeof(BDataType) * return sizeof(BDataType) *
Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size(); Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size();
......
...@@ -9,22 +9,21 @@ namespace ck_tile { ...@@ -9,22 +9,21 @@ namespace ck_tile {
template <typename BlockGemmShape_> template <typename BlockGemmShape_>
struct GemmTilePartitioner struct GemmTilePartitioner
{ {
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr ck_tile::index_t kM = BlockGemmShape::kM; static constexpr index_t kM = BlockGemmShape::kM;
static constexpr ck_tile::index_t kN = BlockGemmShape::kN; static constexpr index_t kN = BlockGemmShape::kN;
static constexpr ck_tile::index_t kK = BlockGemmShape::kK; static constexpr index_t kK = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
{ {
ck_tile::index_t GridDimX = (M + kM - 1) / kM; index_t GridDimX = (M + kM - 1) / kM;
ck_tile::index_t GridDimY = (N + kN - 1) / kN; index_t GridDimY = (N + kN - 1) / kN;
ck_tile::index_t GridDimZ = batch_size; index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ); return dim3(GridDimX, GridDimY, GridDimZ);
} }
CK_TILE_DEVICE static constexpr auto GetLoopNum(index_t K) CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
{ {
return integer_divide_ceil(K, kK); return integer_divide_ceil(K, kK);
} }
...@@ -33,7 +32,7 @@ struct GemmTilePartitioner ...@@ -33,7 +32,7 @@ struct GemmTilePartitioner
{ {
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM); const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
return ck_tile::make_tuple(iM, iN); return make_tuple(iM, iN);
} }
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -19,7 +19,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -19,7 +19,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
...@@ -33,9 +37,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -33,9 +37,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC; static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return integer_divide_ceil(
sizeof(ADataType) * sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 16) *
...@@ -44,7 +48,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -44,7 +48,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
} }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -71,8 +71,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -71,8 +71,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{ {
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...@@ -93,7 +91,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -93,7 +91,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(); MakeALdsBlockDescriptor<Problem>().get_element_space_size();
...@@ -101,7 +99,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -101,7 +99,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
...@@ -109,7 +107,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -109,7 +107,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
constexpr index_t smem_size_a = GetSmemSizeA<Problem>(); constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>(); constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -25,9 +25,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV2 ...@@ -25,9 +25,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK; static constexpr index_t kKPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return integer_divide_ceil(
sizeof(ADataType) * sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 16) *
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
#if defined(__gfx9__) #if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
#if defined(__gfx9__) #if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
0); 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
0); 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
...@@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
}); });
#else #else
ck_tile::ignore = c_vec; ignore = c_vec;
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
#endif #endif
} }
...@@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}); });
return c_vec; return c_vec;
#else #else
ck_tile::ignore = a_vec; ignore = a_vec;
ck_tile::ignore = b_vec; ignore = b_vec;
return CVecType{0.f}; return CVecType{0.f};
#endif #endif
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher; ...@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off // clang-format off
// fp16 // fp16
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16 // bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8 // fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on // clang-format on
} // namespace impl } // namespace impl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment