Commit d78b9359 authored by Jing Zhang's avatar Jing Zhang
Browse files

simple blockwise gemm

parent 956465c6
......@@ -27,8 +27,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 256, 128, 128, 16, 2, 8, 8, 1, S<1, 16>, S<1, 16>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 64, 64, 16, 2, 8, 8, 1, S<1, 8>, S<1, 8>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 64, 64, 16, 2, 4, 4, 1, S<1, 8>, S<1, 8>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 2, 8, 1, S<1, 8>, S<1, 8>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -4,8 +4,8 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp"
namespace ck {
......@@ -13,11 +13,11 @@ template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_E1_K1_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop,
index_t KPerThreadLoop>
typename ABlockDesc_K0_M_K1,
typename BBlockDesc_K0_N_K1,
index_t MPerThread,
index_t NPerThread,
index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
......@@ -26,105 +26,91 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
static constexpr auto M = ABlockDesc_K0_M_K1{}.GetLength(I1);
static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto N = BBlockDesc_K0_N_K1{}.GetLength(I1);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr auto M0 = M / MPerThread;
static constexpr auto M1 = MPerThread;
static constexpr auto N0 = N / NPerThread;
static constexpr auto N1 = NPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));
static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HoPerThread>{},
Number<WoPerThread>{},
Number<E2>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I1] * MPerThread, 0)},
b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I3] * NPerThread, 0)}
{
static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
BBlockDesc_K0_N_K1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BBlockDesc_K0_N_K1{}.GetLength(I0) &&
ABlockDesc_K0_M_K1{}.GetLength(I2) == BBlockDesc_K0_N_K1{}.GetLength(I2),
"wrong! E dimension not consistent\n");
static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
static_assert(K0 % K0PerLoop == 0, "");
static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
WoPerBlock % WoPerThread == 0,
static_assert(M % MPerThread == 0 && N % NPerThread == 0,
"wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = KPerBlock / KPerThread;
constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
__device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
return Sequence<I1, M1, I1, N1>{};
}
__device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
constexpr auto K0 = KPerBlock / KPerThread;
constexpr auto N0 = I1;
constexpr auto H0 = HoPerBlock / HoPerThread;
constexpr auto W0 = WoPerBlock / WoPerThread;
constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(make_merge_transform(make_tuple(I1, M0, I1, N0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_k_n_h_w_thread_cluster_idx =
c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
const auto c_m0_m1_n0_n1_thread_cluster_idx =
c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
return c_k_n_h_w_thread_cluster_idx;
return c_m0_m1_n0_n1_thread_cluster_idx;
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename BBlockBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
constexpr auto b_block_mtx = BBlockDesc_K0_N_K1{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, FloatB, b_thread_mtx_.GetElementSpaceSize(), true>
b_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
......@@ -132,46 +118,55 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin, I0),
make_tuple(k0_begin, I0, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_mtx,
make_tuple(k0_begin, I0, I0),
b_block_buf,
b_thread_mtx_,
make_tuple(I0, I0, I0),
b_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, I0, I0, I0),
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(k_begin, I0, I0, I0));
});
make_tuple(I0, I0, I0, I0));
});
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
}
private:
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA,
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_E1_K1_E2,
ABlockDesc_K0_M_K1,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<K0PerLoop, MPerThread, K1>,
Sequence<0, 1, 2>,
2,
K1,
K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc_K0_N_K1,
decltype(b_thread_mtx_),
Sequence<K0PerLoop, NPerThread, K1>,
Sequence<0, 1, 2>,
2,
E2,
E2>;
K1,
K1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
......
......@@ -371,120 +371,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float ave_time = 0;
#if 0
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
#else
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
true,
true>;
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -504,11 +394,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
true,
false>;
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -528,11 +414,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
false,
true>;
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -552,11 +434,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
false,
false>;
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -573,7 +451,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
arg.StrideB_,
arg.StrideC_);
}
#endif
return ave_time;
}
......
......@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
......@@ -198,7 +199,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
static constexpr auto K1Number = Number<K1>{};
__host__ __device__ static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
__host__ __device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
......@@ -237,7 +239,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
}
}
__host__ __device__ static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
__host__ __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
......@@ -333,7 +336,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
......@@ -420,6 +422,322 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
#if 1
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0),
a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0),
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t MPerThread = M1PerThreadM111;
constexpr index_t NPerThread = N1PerThreadN111;
const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerThread,
NPerThread,
KPerThread>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1] * MPerThread,
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3] * NPerThread),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
#else
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
......@@ -710,6 +1028,15 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
if(get_block_1d_id() == 0)
{
printf("%d %d %d %d\n",
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]);
}
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
......@@ -1284,6 +1611,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
c_grid_buf);
}
}
#endif
};
} // namespace ck
......@@ -4,26 +4,21 @@
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "math.hpp"
#include "ck/utility/common_header.hpp"
namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// C[M, N] += transpose(A[M, M]) * B[M, N]
// Element of matrix can be vectorized data
// Assume:
// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
// compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_E1_K_E2,
typename BThreadDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
typename AThreadDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
typename CThreadDesc_M_N,
typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
......@@ -42,9 +37,9 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
COriginIdx)
{
static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
......@@ -61,96 +56,29 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
constexpr auto M = AThreadDesc_K0_M_K1{}.GetLength(I1);
constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);
constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
constexpr auto N = BThreadDesc_K0_N_K1{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
{
constexpr auto SubHW = 2;
static_for<0, K, 1>{}([&](auto k) {
static_for<0, Ho, SubHW>{}([&](auto h) {
static_for<0, Wo, SubHW>{}([&](auto w) {
static_for<0, E1, 1>{}([&](auto e1) {
static_for<0, E2, 1>{}([&](auto e2) {
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
a_origin_idx + make_tuple(e1, k, e2));
constexpr index_t b0_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e1, 0, h, w, e2));
constexpr index_t b1_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
constexpr index_t b2_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
constexpr index_t b3_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
constexpr index_t c0_offset =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
make_tuple(k, 0, h, w));
constexpr index_t c1_offset =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w + 1));
constexpr index_t c2_offset =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h + 1, w));
constexpr index_t c3_offset =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b0_offset>{}],
b_buf[Number<b1_offset>{}],
b_buf[Number<b2_offset>{}],
b_buf[Number<b3_offset>{}],
c_buf(Number<c0_offset>{}),
c_buf(Number<c1_offset>{}),
c_buf(Number<c2_offset>{}),
c_buf(Number<c3_offset>{}));
});
});
});
});
});
}
else
{
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
a_origin_idx + make_tuple(k0, m, k1));
static_for<0, K, 1>{}([&](auto k) {
static_for<0, Ho, 1>{}([&](auto h) {
static_for<0, Wo, 1>{}([&](auto w) {
static_for<0, E1, 1>{}([&](auto e1) {
static_for<0, E2, 1>{}([&](auto e2) {
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
a_origin_idx + make_tuple(e1, k, e2));
constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
b_origin_idx + make_tuple(k0, n, k1));
constexpr index_t b_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e1, 0, h, w, e2));
constexpr index_t c_offset =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
make_tuple(k, 0, h, w));
constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
c_origin_idx + make_tuple(0, m, 0, n));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
......@@ -159,9 +87,7 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
});
});
});
});
}
}
} // namespace ck
};
} // namespace ck
......
......@@ -9,6 +9,12 @@ namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
c += a * b;
}
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment