"vscode:/vscode.git/clone" did not exist on "51361cb66b8e2a0f750091ba8e5ce96da2c8a8f3"
Unverified Commit 37a8c1f7 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski Committed by GitHub
Browse files

Redesign the DPP8 GEMM kernel to use warp-wise component (#863)

* Redesign the DPP8 GEMM kernel to use warp-wise component

* Review: Improve error messages

* Review: Remove unnecessary empty lines

* Review: Fix M, N per thread names

* Review: Rename mfma_input_type to dpp_input_type

* Review: Fix tensor adaptor; remove unnecessary element

* Review: Remove calls to dpp_gemm's MakeCDescriptor

* Review: Add blockwise doc, change function names to include dimension names

* Review: Remove duplicated code; Move Block2CtileMap alias to the top of the file

* Review: Add __restrict__ keywords

* Review: Use MatrixPadder for padding A, B, C matrices

* Review: Remove hardcoded datatypes

* Review: Change names from FloatX to XDataType

* Review: Introduce AK0 and BK0 instead of a single K0

* Review: Remove construction of dpp_datatypes object

* Review: Rename DppInstrRunner to DppLanegroupGemm
parent 3786bfe1
...@@ -6,8 +6,7 @@ if(DL_KERNELS) ...@@ -6,8 +6,7 @@ if(DL_KERNELS)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_fp16) add_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
......
...@@ -3,31 +3,33 @@ ...@@ -3,31 +3,33 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp"
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using ALayout = Col; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>;
// clang-format on // // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"
namespace ck {
/**
* DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit
* the data loaded from LDS to registers.
*
* The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C
* between them in such a way that threads from the same group need the same chunk of either
* matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the
* whole chunk from LDS to its own register space.
* Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size`
* of the chunk, and then share that data with other threads from the same lane group.
*
* Assumptions coming from the usage of DPP8:
* 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
* `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` -
* - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
* matrix A or B;
* - based on these values we determine which matrix to share.
* 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
* `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) -
* - we have to make sure that the data to split is divisible by the number of
* threads in the group.
*
* General algorithm:
* C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
* A and B are visible to the whole block, C is distributed among each thread
* Assume:
* 1. A:
* 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
* 2. ABlockBuffer is DynamicBuffer
* 2. B:
* 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
* 2. BBlockBuffer is DynamicBuffer
* 3. C:
* 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
* 2. CThreadBuffer is StaticBuffer
* 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
*/
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
using AIndex = MultiIndex<4>;
using BIndex = MultiIndex<4>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = BM100 * BM101 * BM11;
static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all
// threads in a lane group need the same chunk of B or A matrices and we can share them using
// DPP.
static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size);
static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false;
static constexpr bool ShareA = !ShareB;
// If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`,
// respectively) elements, so we split them between threads in lane group so each thread loads
// less data from LDS.
static constexpr index_t BM1PerThread =
ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11;
static constexpr index_t BN1PerThread =
ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11;
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()},
b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()}
{
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
__device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()
{
const auto offsetBM0 = c_thread_origin_data_idx_[I0];
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] +
dpp8::get_thread_idx_in_lane_group() * BM1PerThread
: c_thread_origin_data_idx_[I1];
return make_tuple(0, offsetBM0, offsetBM1, 0);
}
__device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()
{
const auto offsetBN0 = c_thread_origin_data_idx_[I2];
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] +
dpp8::get_thread_idx_in_lane_group() * BN1PerThread
: c_thread_origin_data_idx_[I3];
return make_tuple(0, offsetBN0, offsetBN1, 0);
}
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>,
ShareA>{};
static_for<0, BN0, 1>{}([&](auto bn0) {
static_for<0, BM0, 1>{}([&](auto bm0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
});
});
});
}
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThread>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThread>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp"
namespace ck {
/**
* Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each
* thread by sharing the data between threads in a lanegroup.
*
* In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are
* `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one.
* In total, the algorithm runs using
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
*/
template <index_t BlockSize,
typename ABDataType,
typename AccDataType,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerDpp,
index_t NPerDpp,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
MRepeat * NRepeat,
dpp_gemm.GetRegSizePerDpp(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
const auto dpp_a_idx_k = dpp_a_idx[I0];
const auto dpp_a_idx_m = dpp_a_idx[I1];
return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
}
__device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
const auto dpp_b_idx_k = dpp_b_idx[I0];
const auto dpp_b_idx_n = dpp_b_idx[I1];
return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
const auto blk_m_offset = blk_idx[I0];
const auto blk_n_offset = blk_idx[I1];
constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_m_offset))[I0];
const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_n_offset))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"Wrong! Block descriptors should be known at the time of compilation.");
#if defined(__HIP_DEVICE_COMPILE__)
// Host wave size can be different than the device one and this assert could fail for host,
// but it does matter only for device.
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
#endif
static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
"Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
"Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
constexpr auto M = c_m_n_tblk_lens[I0];
constexpr auto N = c_m_n_tblk_lens[I1];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
constexpr auto M = c_m_n_tblk_lens[I0];
constexpr auto N = c_m_n_tblk_lens[I1];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerDpp>{},
Number<NPerDpp>{}));
return c_block_desc_m0_n0_m1_n1_m2_n2;
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerDpp>{},
Number<NPerDpp>{}));
return c_block_desc_g_m0_n0_m1_n1_m2_n2;
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return c_grid_desc_m0_n0_m1_n1_m2_n2;
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerDpp>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerDpp>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<ABDataType, KPack> a_thread_vec;
vector_type<ABDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using dpp_input_type =
typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
b_thread_vec.template AsType<dpp_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// C[M, N, NumRegDpp]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
ABDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
ABDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
};
} // namespace ck
...@@ -4,27 +4,13 @@ ...@@ -4,27 +4,13 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck { namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1> template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct GemmDlAlgorithm
{
Default, // Uses DOT vector instructions
Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -60,7 +59,6 @@ template < ...@@ -60,7 +59,6 @@ template <
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
enable_if_t< enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
...@@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector>;
GemmDlAlg>;
using AGridDesc_K0_M0_M1_K1 = using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
...@@ -375,8 +372,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -375,8 +372,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>, remove_reference_t<DefaultBlock2CTileMap>,
true, true,
true, true>;
GemmDlAlg>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -402,8 +398,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -402,8 +398,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>, remove_reference_t<DefaultBlock2CTileMap>,
true, true,
false, false>;
GemmDlAlg>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -429,8 +424,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -429,8 +424,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>, remove_reference_t<DefaultBlock2CTileMap>,
false, false,
true, true>;
GemmDlAlg>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -456,8 +450,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -456,8 +450,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>, remove_reference_t<DefaultBlock2CTileMap>,
false, false,
false, false>;
GemmDlAlg>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -492,16 +485,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -492,16 +485,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
if(ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
return false;
}
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102") ck::get_device_name() == "gfx1102")
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlgorithm::Dpp8>
{
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDlDpp8"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerDpp,
ck::index_t NPerDpp,
ck::index_t MDppPerWave,
ck::index_t NDppPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmDpp : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp<
BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
MPerBlock,
NPerBlock,
KPerBlock,
MPerDpp,
NPerDpp,
AK1,
BK1,
MDppPerWave,
NDppPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumPrefetch,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
karg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
}
const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(karg);
}
return false;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmDpp"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerDpp << ", "
<< NPerDpp << ", "
<< MDppPerWave << ", "
<< MDppPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -7,11 +7,9 @@ ...@@ -7,11 +7,9 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
...@@ -19,8 +17,6 @@ ...@@ -19,8 +17,6 @@
namespace ck { namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -29,8 +25,7 @@ template <typename GridwiseGemm, ...@@ -29,8 +25,7 @@ template <typename GridwiseGemm,
typename CGridDesc_M0_M10_M11_N0_N10_N11, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop, bool HasDoubleTailKBlockLoop>
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -43,13 +38,6 @@ __global__ void ...@@ -43,13 +38,6 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -100,8 +88,7 @@ template <index_t BlockSize, ...@@ -100,8 +88,7 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector>
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3 struct GridwiseGemmDl_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
c_grid_desc_m_n); c_grid_desc_m_n);
} }
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
...@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>(); BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
......
This diff is collapsed.
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
namespace ck {
enum struct DppInstr
{
dpp8_f16_16x16x2 = 0,
dpp8_f16_8x32x2,
dpp8_f16_32x8x2
};
/**
* Structure representing DPP GEMM executed by a single wavefront.
*
* Each structure instantiation must contain the following fields:
* - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
* number of threads in a wavefront;
* - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
* it's 8 in case of DPP8;
* - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
* - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
* - m_per_thread - size along M dimension of the tile calculated by a single thread;
* - n_per_thread - size along N dimension of the tile calculated by a single thread;
* - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
* - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
*
* Not all the combinarions are supported now, for current restrictions see the static asserts
* in the DppSelector's contructor.
*/
template <DppInstr instr>
struct dpp_type;
template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 32;
static constexpr index_t n_per_wave = 8;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 16;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
static constexpr auto GetDpp();
template <>
static constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
{
static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
static_assert(selected_dpp.k_per_dpp % 2 == 0);
static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
constexpr index_t num_dpp_c_elems =
selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
if constexpr(selected_dpp.share_a)
{
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
}
else
{
static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
}
// Below checks come from the restrictions of the current implementation, could be removed
// in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
}
static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__host__ __device__ constexpr DppGemm()
{
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
"MPerDpp must be either 8, 16 or 32.");
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
"NPerDpp must be either 8, 16 or 32.");
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
}
__device__ static constexpr index_t GetRegSizePerDpp()
{
return MPerDpp * NPerDpp / dpp_instr.wave_size;
}
template <class ADataType, class BDataType, class CDataType>
__device__ void
Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
{
static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
"base BaseType must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneIdInWave()
{
return get_thread_local_1d_id() % dpp_instr.wave_size;
}
__device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
__device__ static auto GetLaneIdInLaneGroup()
{
return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
}
__device__ static auto GetLaneGroupIdInWave()
{
return GetLaneIdInWave() / dpp_instr.lanegroup_size;
}
__device__ static auto GetDppOpIdx()
{
const auto lanegroupId = GetLaneGroupIdInWave();
constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
make_multi_index(lanegroupId));
const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1];
return make_tuple(m_dpp_idx, n_dpp_idx);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
{
const auto laneId = get_thread_local_1d_id();
const auto wave_row = laneId / dpp_instr.n_per_wave;
auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
return make_tuple(0, m_idx % dpp_instr.m_per_wave);
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
{
const auto laneId = get_thread_local_1d_id();
return make_tuple(0, laneId % dpp_instr.n_per_wave);
}
__device__ static CIndex GetBeginOfThreadBlk()
{
const auto dpp_op_idx = GetDppOpIdx();
const auto m_dpp_op_idx = dpp_op_idx[I0];
const auto n_dpp_op_idx = dpp_op_idx[I1];
index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
return CIndex{m_offset, n_offset};
}
static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
static constexpr auto dpp_instr = dpp.selected_dpp;
static constexpr auto K0PerDpp = 1;
static constexpr auto K1PerDpp = dpp.GetK1PerDpp();
__host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
{
return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
}
};
} // namespace ck
...@@ -5,17 +5,63 @@ ...@@ -5,17 +5,63 @@
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/amd_gemm_dpp.hpp" #include "ck/utility/inner_product_dpp8.hpp"
namespace ck { namespace ck {
namespace dpp8 { namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers. template <class ABDataType>
constexpr index_t lane_group_size = 8; struct dpp_datatypes;
__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } template <>
__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } struct dpp_datatypes<half_t>
{
// Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
// single instruction.
using a_dtype = half_t;
using b_dtype = half_t;
using c_dtype = float;
static constexpr index_t k_per_instr = 2;
};
template <index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
class BaseInputType,
class AVecDataType,
class BVecDataType,
class CVecDataType,
bool ShareA>
struct DppLanegroupGemm
{
using datatypes_conf = dpp_datatypes<BaseInputType>;
using ADataType = typename datatypes_conf::a_dtype;
using BDataType = typename datatypes_conf::b_dtype;
using CDataType = typename datatypes_conf::c_dtype;
__device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
{
constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
const vector_type<ADataType, KPerThread> a_vector{a_vec};
const vector_type<BDataType, KPerThread> b_vector{b_vec};
static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
float c = c_vec.template AsType<CDataType>()(c_idx);
// Next `c_idx` implies that we need to pull data from the next lane.
constexpr index_t source_lane = c_idx;
static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
ck::dpp8::
inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
a_k_vec, b_k_vec, c);
});
c_vec.template AsType<CDataType>()(c_idx) = c;
});
}
};
} // namespace dpp8 } // namespace dpp8
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "amd_gemm_dpp.hpp" #include "amd_gemm_dpp.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "type_convert.hpp" #include "type_convert.hpp"
...@@ -10,6 +11,9 @@ namespace ck { ...@@ -10,6 +11,9 @@ namespace ck {
namespace dpp8 { namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;
template <int SrcLaneIdx> template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
} // namespace ck
...@@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( ...@@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
...@@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( ...@@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
...@@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( ...@@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
...@@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( ...@@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
...@@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory< ...@@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif #endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
} }
...@@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory< ...@@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif #endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
...@@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory< ...@@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#endif #endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
} }
...@@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory< ...@@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#endif #endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
} }
......
...@@ -31,10 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -31,10 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp)
endif() endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
>;
void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment