Commit b8d37c10 authored by Chao Liu

clean up

parent d0b9a467
@@ -4,7 +4,6 @@
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
@@ -13,7 +12,6 @@
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
-#include "device_gemm_xdl_producer_consumer_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
@@ -45,28 +43,12 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
#if 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#elif 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
...
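For reference, the instance kept above fixes the problem's layout and type contract: row-major A, column-major B, row-major C, fp16 storage with fp32 accumulation, and PassThrough (identity) elementwise operators. A minimal host-side sketch of that contract, illustrative only and not the CK device API (plain float buffers stand in for the fp16 device tensors):

// Sketch: C(m, n) = sum_k A(m, k) * B(k, n) with A row-major, B column-major,
// C row-major, fp32 accumulation, identity elementwise ops (PassThrough).
#include <cstddef>
#include <vector>

void reference_gemm_rcr(const std::vector<float>& a, // M x K, row-major
                        const std::vector<float>& b, // K x N, column-major
                        std::vector<float>& c,       // M x N, row-major
                        std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f; // AccDataType = F32
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[n * K + k]; // column-major B: B(k, n) = b[n * K + k]
            c[m * N + n] = acc; // stored as F16 on the device side
        }
}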
@@ -14,8 +14,8 @@
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
-#define CK_MAX_THREAD_PER_BLOCK 512
-#define CK_MIN_BLOCK_PER_CU 1
+#define CK_MAX_THREAD_PER_BLOCK 256
+#define CK_MIN_BLOCK_PER_CU 2
#endif
// check GPU target
...
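The launch-bounds change above trades one 512-thread block per CU for two 256-thread blocks per CU, which also tightens the per-thread register budget the compiler assumes. A sketch of how macros of this kind are normally consumed by a HIP/CUDA kernel declaration (the kernel and wiring below are illustrative assumptions, not the actual CK macro usage):

// Sketch only: applying launch-bound macros to a HIP kernel.
#include <hip/hip_runtime.h>

#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2

// At most 256 threads per block; ask the compiler to budget registers so that
// at least 2 blocks can be resident per compute unit.
__global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
example_kernel(float* out)
{
    out[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f;
}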
@@ -3,10 +3,11 @@
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
+#include "thread_group.hpp"
namespace ck {
-template <typename ThreadGroup,
+template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
@@ -23,6 +24,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
+using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
@@ -53,7 +56,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__device__ static auto GetWaveIdx()
{
-const index_t thread_id = ThreadGroup::GetThreadId();
+const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
@@ -120,8 +123,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
-static_assert(ThreadGroup::GetNumOfThread() == MWaves * NWaves * WaveSize,
-"ThreadGroup::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
+static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
...
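GetWaveIdx() above decomposes the flat thread id into (wave-m, wave-n, lane) through a merge-transform adaptor over (MWaves, NWaves, WaveSize). A plain-arithmetic sketch of the same decomposition (the helper names here are illustrative, not from the source):

// Sketch: arithmetic equivalent of the (MWaves, NWaves, WaveSize) merge transform.
struct WaveIdx
{
    unsigned waveId_m; // wave row within the block tile
    unsigned waveId_n; // wave column within the block tile
    unsigned lane;     // lane id inside the wave
};

constexpr WaveIdx decompose_thread_id(unsigned thread_id, unsigned NWaves, unsigned WaveSize)
{
    const unsigned lane    = thread_id % WaveSize;
    const unsigned wave_id = thread_id / WaveSize; // 0 .. MWaves * NWaves - 1
    return WaveIdx{wave_id / NWaves, wave_id % NWaves, lane};
}
// e.g. BlockSize = 256, WaveSize = 64, MWaves = 2, NWaves = 2:
// thread 130 -> wave 2 -> (waveId_m, waveId_n, lane) = (1, 0, 2)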
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v3r1.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
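The constructor above assigns each thread a sub-slice of the block window: thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths, and the per-thread origin is thread_cluster_idx * thread_slice_lengths. A small host sketch with the A-copy numbers from the DeviceGemm instance at the top of this commit (AK0 = 32 / 8 = 4, MPerBlock = 256, AK1 = 8, thread cluster 4 x 64 x 1 = 256 threads); the code below is illustrative, not part of the library:

// Sketch: how the blockwise copy splits its slice among threads.
#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 3> block_slice_lengths{4, 256, 8};   // K0 x M x K1 per block
    const std::array<int, 3> thread_cluster_lengths{4, 64, 1}; // 4 * 64 * 1 = 256 threads

    std::array<int, 3> thread_slice_lengths{};
    for(int d = 0; d < 3; ++d)
        thread_slice_lengths[d] = block_slice_lengths[d] / thread_cluster_lengths[d];

    // thread_slice_lengths = {1, 4, 8}: each thread copies a 1 x 4 x 8 sub-slice whose
    // origin is thread_cluster_idx * thread_slice_lengths, exactly the
    // thread_data_idx_begin computed in the constructor above.
    std::printf("%d %d %d\n",
                thread_slice_lengths[0], thread_slice_lengths[1], thread_slice_lengths[2]);
}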
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r1<SrcData,
DstData,
SrcDesc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
@@ -14,7 +14,7 @@ namespace ck {
template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
-typename BlockSliceLengths,
+typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
@@ -30,7 +30,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
-static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
@@ -54,7 +54,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
"wrong! nDim not consistent");
static_assert(
-is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
...
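The renamed ThreadGroupTensorSliceTransfer_* templates take a ThreadGroup type where the old Blockwise* versions took a raw BlockSize. In this diff the parameter is only used through GetNumOfThread(), GetThreadId() and, in the producer-consumer pipeline, IsBelong(). A sketch of a type providing that interface, assuming a whole-thread-block group (the real ThisThreadBlock in thread_group.hpp may look different):

// Sketch of the interface a ThreadGroup parameter is used with in this diff.
#include <hip/hip_runtime.h>

template <unsigned BlockSize>
struct ExampleThisThreadBlock // illustrative name, not the library's type
{
    __device__ static constexpr unsigned GetNumOfThread() { return BlockSize; }

    // flat thread id within the block (what get_thread_local_1d_id() returns)
    __device__ static unsigned GetThreadId() { return threadIdx.x; }

    // every thread of the block belongs to this group
    __device__ static bool IsBelong() { return true; }
};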
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
+#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
-typename BlockSliceLengths,
+typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
@@ -31,21 +29,21 @@ template <index_t BlockSize,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r2
+struct ThreadGroupTensorSliceTransfer_v6r2
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
-static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
-__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
+__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
@@ -64,17 +62,17 @@ struct BlockwiseTensorSliceTransfer_v6r2
"wrong! nDim not consistent");
static_assert(
-is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
-static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-"wrong! BlockSize too small");
+static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+"wrong! ThreadGroup::GetNumOfThread() too small");
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-make_multi_index(get_thread_local_1d_id()));
+make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
}
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
};
} // namespace ck
-#endif
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
+#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
-typename BlockSliceLengths,
+typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
@@ -34,23 +32,23 @@ template <index_t BlockSize,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r3
+struct ThreadGroupTensorSliceTransfer_v6r3
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
-static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
-__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
+__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
@@ -72,14 +70,14 @@ struct BlockwiseTensorSliceTransfer_v6r3
"wrong! nDim not consistent");
static_assert(
-is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
-static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-"wrong! BlockSize too small");
+static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+"wrong! ThreadGroup::GetNumOfThread() too small");
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
}
@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
-if(BlockSize == thread_cluster_desc_.GetElementSize() or
-get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
};
} // namespace ck
-#endif
#pragma once
#include "common_header.hpp"
namespace ck {
template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
index_t NumGemmKPrefetchStage>
struct GridwiseGemmPipelineProducerConsumer;
// 1-stage prefetch
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop / 2 > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_sync_lds();
// move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_sync_lds();
// LDS write num_loop - 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
// GEMM num_loop - 1
}
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 2
// LDS write i + 1
// global read i + 2
// LDS write i + 1
// global read i + 2
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(ABBlockTransferThreadGroup::IsBelong())
{
RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
a_block_desc,
a_block_copy,
a_grid_buf,
a_block_buf,
a_block_copy_step,
b_grid_desc,
b_block_desc,
b_block_copy,
b_grid_buf,
b_block_buf,
b_block_copy_step,
num_loop);
}
else if(BlockGemmThreadGroup::IsBelong())
{
RunBlockGemmPipeline<HasMainLoop>(
a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
}
}
};
} // namespace ck
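The pipeline above splits the thread block into two cooperating groups: one streams A/B tiles from global memory into LDS, the other runs the blockwise GEMM on tiles already in LDS, with block_sync_lds() (the barrier used throughout this file) as the hand-off, and one extra tile kept in flight. A stripped-down sketch of that role split, without the one-tile prefetch and with illustrative callables standing in for the copy and GEMM work:

// Sketch only: producer/consumer role split synchronized by the block-wide barrier.
template <typename CopyGroup, typename GemmGroup, typename CopyWork, typename GemmWork>
__device__ void run_producer_consumer(CopyWork copy_tile, GemmWork gemm_tile, int num_loop)
{
    if(CopyGroup::IsBelong())
    {
        for(int i = 0; i < num_loop; ++i)
        {
            copy_tile(i);     // global -> registers -> LDS for tile i (plus window moves)
            block_sync_lds(); // publish tile i to the consumer group
            block_sync_lds(); // wait until the consumer has finished reading tile i
        }
    }
    else if(GemmGroup::IsBelong())
    {
        for(int i = 0; i < num_loop; ++i)
        {
            block_sync_lds(); // wait for tile i to be visible in LDS
            gemm_tile(i);     // accumulate C from the LDS tiles
            block_sync_lds(); // tell the producer that tile i has been consumed
        }
    }
}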
@@ -4,8 +4,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
@@ -51,20 +51,20 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_d0_grid,
p_d1_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
@@ -403,28 +403,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// A matrix blockwise copy
auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
@@ -434,28 +434,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// B matrix blockwise copy
auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
@@ -474,7 +474,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
@@ -636,8 +636,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
-auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
-BlockSize, // index_t BlockSize,
+auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
...
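For orientation, the a/b blockwise copies constructed in this hunk expose the same RunRead / RunWrite / MoveSrcSliceWindow interface as the v4r1 transfer shown earlier. A simplified, non-prefetching sketch of how the gridwise K loop drives them together with blockwise_gemm; names mirror the hunk, and the real schedule in gridwise_gemm_pipeline_v1.hpp overlaps the next global read with the current GEMM:

// Sketch only: one K tile per iteration, no prefetch.
template <typename ACopy, typename BCopy, typename BlockGemm,
          typename AGridDesc, typename BGridDesc, typename ABlockDesc, typename BBlockDesc,
          typename AGridBuf, typename BGridBuf, typename ABlockBuf, typename BBlockBuf,
          typename CThreadBuf, typename AStep, typename BStep>
__device__ void run_k_loop_sketch(ACopy& a_copy, BCopy& b_copy, const BlockGemm& block_gemm,
                                  const AGridDesc& a_grid_desc, const BGridDesc& b_grid_desc,
                                  const ABlockDesc& a_block_desc, const BBlockDesc& b_block_desc,
                                  const AGridBuf& a_grid_buf, const BGridBuf& b_grid_buf,
                                  ABlockBuf& a_block_buf, BBlockBuf& b_block_buf,
                                  CThreadBuf& c_thread_buf,
                                  const AStep& a_step, const BStep& b_step, int num_k_loop)
{
    c_thread_buf.Clear();
    for(int i = 0; i < num_k_loop; ++i)
    {
        a_copy.RunRead(a_grid_desc, a_grid_buf);    // global -> per-thread buffer
        b_copy.RunRead(b_grid_desc, b_grid_buf);
        a_copy.MoveSrcSliceWindow(a_grid_desc, a_step);
        b_copy.MoveSrcSliceWindow(b_grid_desc, b_step);

        block_sync_lds();                           // previous tile no longer being read
        a_copy.RunWrite(a_block_desc, a_block_buf); // per-thread buffer -> LDS
        b_copy.RunWrite(b_block_desc, b_block_buf);
        block_sync_lds();                           // LDS tile visible to every wave

        block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); // accumulate C in registers
    }
}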
@@ -422,7 +422,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
...
@@ -4,7 +4,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
@@ -287,7 +287,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}();
using BlockwiseGemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
@@ -386,28 +386,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix blockwise copy
auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
@@ -417,28 +417,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix blockwise copy
auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
@@ -455,7 +455,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check
auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
...
@@ -6,7 +6,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
@@ -422,27 +422,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
// A matrix blockwise copy
auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
@@ -452,27 +452,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// B matrix blockwise copy
auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadGroup,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
...
@@ -6,8 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 namespace ck {
@@ -411,27 +411,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 }();
 // A matrix blockwise copy
 auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 AElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<1, K0PerBlock, MPerBlock, K1>,
 ABlockTransferThreadClusterLengths_K0_M_K1,
 ABlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(a_b_k0_m_k1_grid_desc),
 decltype(a_b_k0_m_k1_block_desc),
 ABlockTransferSrcAccessOrder,
 Sequence<0, 2, 1, 3>,
 ABlockTransferSrcVectorDim,
 3,
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_K1,
 1,
 1,
 AThreadTransferSrcResetCoordinateAfterRun,
 true>(
 a_b_k0_m_k1_grid_desc,
 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
 a_element_op,
@@ -441,27 +441,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 // B matrix blockwise copy
 auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 BElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<1, K0PerBlock, NPerBlock, K1>,
 BBlockTransferThreadClusterLengths_K0_N_K1,
 BBlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(b_b_k0_n_k1_grid_desc),
 decltype(b_b_k0_n_k1_block_desc),
 BBlockTransferSrcAccessOrder,
 Sequence<0, 2, 1, 3>,
 BBlockTransferSrcVectorDim,
 3,
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_K1,
 1,
 1,
 BThreadTransferSrcResetCoordinateAfterRun,
 true>(
 b_b_k0_n_k1_grid_desc,
 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
 b_element_op,
@@ -478,7 +478,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 // sanity check
 auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 FloatAB,
 FloatAcc,
 decltype(a_k0_m_k1_block_desc),
@@ -662,8 +662,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 ck::tensor_operation::element_wise::PassThrough{}};
 // LDS to global
-auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
-BlockSize, // index_t BlockSize,
+auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+ThisThreadBlock, // index_t BlockSize,
 CElementwiseOperation, // ElementwiseOperation,
 CGlobalMemoryDataOperation, // DstInMemOp,
 Sequence<1,
...
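Editor's note: in this file the LDS-to-global copy also switches from BlockwiseTensorSliceTransfer_v6r1<BlockSize, ...> to ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock, ...>. One thing a group type can do that a bare integer cannot is let the transfer validate its thread-cluster shape at compile time. The sketch below is a hedged illustration of that idea with made-up names; it is not the library's actual check.

#include <cstdio>

template <int NThread>
struct ExampleThreadGroup
{
    static constexpr int GetNumOfThread() { return NThread; }
};

// A copy that distributes work over a K0 x M x K1 thread cluster, driven by ThreadGroup.
template <typename ThreadGroup, int ClusterK0, int ClusterM, int ClusterK1>
struct ExampleBlockwiseCopy
{
    static_assert(ClusterK0 * ClusterM * ClusterK1 == ThreadGroup::GetNumOfThread(),
                  "thread cluster must exactly cover the driving thread group");
};

int main()
{
    // 4 x 64 x 1 = 256 threads, mirroring cluster lengths such as
    // ABlockTransferThreadClusterLengths_K0_M_K1 for a 256-thread block.
    ExampleBlockwiseCopy<ExampleThreadGroup<256>, 4, 64, 1> copy{};
    (void)copy;
    std::puts("cluster shape matches thread group");
}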
@@ -6,8 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "threadwise_tensor_slice_transfer_v4r1.hpp"
+#include "threadwise_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
 #include "tensor_space_filling_curve.hpp"
@@ -405,28 +405,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 // A matrix blockwise copy
 auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 AElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<AK0, MPerBlock, AK1>,
 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 ABlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(a_grid_desc_ak0_m_ak1),
 decltype(a_block_desc_ak0_m_ak1),
 ABlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 ABlockTransferSrcVectorDim,
 2,
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_K1,
 1,
 1,
 AThreadTransferSrcResetCoordinateAfterRun,
 true,
 NumGemmKPrefetchStage>(
 a_grid_desc_ak0_m_ak1,
 make_multi_index(0, m_block_data_idx_on_grid, 0),
 a_element_op,
@@ -436,28 +436,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 // B matrix blockwise copy
 auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 BElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<BK0, NPerBlock, BK1>,
 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 BBlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(b_grid_desc_bk0_n_bk1),
 decltype(b_block_desc_bk0_n_bk1),
 BBlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 BBlockTransferSrcVectorDim,
 2,
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_K1,
 1,
 1,
 BThreadTransferSrcResetCoordinateAfterRun,
 true,
 NumGemmKPrefetchStage>(
 b_grid_desc_bk0_n_bk1,
 make_multi_index(0, n_block_data_idx_on_grid, 0),
 b_element_op,
@@ -646,7 +646,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 ck::tensor_operation::element_wise::PassThrough{}};
 // LDS to global
-auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
+auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
 BlockSize, // index_t BlockSize,
 CElementwiseOperation, // ElementwiseOperation,
 CGlobalMemoryDataOperation, // DstInMemOp,
...
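Editor's note: unlike the v2r4 variants, the v3r1 copies above carry an extra trailing NumGemmKPrefetchStage template argument. The sketch below illustrates, under stated assumptions, what such a compile-time stage count typically controls: one staging buffer per in-flight K-slice so the next slice can be loaded while the current one is consumed. All names are illustrative; this is not the library's pipeline.

#include <array>
#include <cstddef>
#include <cstdio>

template <int NumPrefetchStage, int ElementsPerStage>
struct ExamplePrefetchedCopy
{
    // one staging buffer per in-flight prefetch stage
    std::array<std::array<float, ElementsPerStage>, NumPrefetchStage> staging{};

    void fill_stage(int stage, float value)
    {
        for(auto& e : staging[static_cast<std::size_t>(stage)])
            e = value;
    }
};

int main()
{
    ExamplePrefetchedCopy<2, 8> copy{};
    copy.fill_stage(0, 1.0f); // slice consumed by the current K-loop iteration
    copy.fill_stage(1, 2.0f); // slice prefetched for the next iteration
    std::printf("%.1f %.1f\n", copy.staging[0][0], copy.staging[1][0]);
}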
@@ -6,8 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r2.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r2.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -426,28 +426,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 // A matrix blockwise copy
 auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 AElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<K0PerBlock, MPerBlock, K1>,
 ABlockTransferThreadClusterLengths_K0_M_K1,
 ABlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(a_grid_desc_k0_m_k1),
 decltype(a_block_desc_k0_m_k1),
 ABlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 ABlockTransferSrcVectorDim,
 2,
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_K1,
 1,
 1,
 AThreadTransferSrcResetCoordinateAfterRun,
 true,
 NumGemmKPrefetchStage>(
 a_grid_desc_k0_m_k1,
 make_multi_index(0, m_block_data_idx_on_grid, 0),
 a_element_op,
@@ -457,28 +457,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 // B matrix blockwise copy
 auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 BElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<K0PerBlock, NPerBlock, K1>,
 BBlockTransferThreadClusterLengths_K0_N_K1,
 BBlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(b_grid_desc_k0_n_k1),
 decltype(b_block_desc_k0_n_k1),
 BBlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 BBlockTransferSrcVectorDim,
 2,
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_K1,
 1,
 1,
 BThreadTransferSrcResetCoordinateAfterRun,
 true,
 NumGemmKPrefetchStage>(
 b_grid_desc_k0_n_k1,
 make_multi_index(0, n_block_data_idx_on_grid, 0),
 b_element_op,
@@ -495,7 +495,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 // sanity check
 auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
@@ -664,8 +664,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 n_thread_data_on_block_idx[I2]),
 ck::tensor_operation::element_wise::PassThrough{}};
-auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r2<
-BlockSize, // index_t BlockSize,
+auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2<
+ThisThreadBlock, // index_t BlockSize,
 CElementwiseOperation, // ElementwiseOperation,
 CGlobalMemoryDataOperation, // DstInMemOp,
 Sequence<1,
...
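Editor's note: across all of these files the pattern is the same, the first template argument of the blockwise copies becomes a thread-group type. A plausible motivation (an assumption on the editor's part, not stated in the diff) is reuse: the same transfer template can then be driven by groups other than a full thread block, such as a single wavefront. A toy sketch with made-up names:

#include <cstdio>

template <int NThread>
struct ExampleGroup
{
    static constexpr int GetNumOfThread() { return NThread; }
};

using ExampleBlock = ExampleGroup<256>; // a whole thread block
using ExampleWave  = ExampleGroup<64>;  // a single wavefront

// One transfer template, usable with any group that models GetNumOfThread().
template <typename ThreadGroup>
struct ExampleTransfer
{
    static constexpr int WorkPerThread(int total) { return total / ThreadGroup::GetNumOfThread(); }
};

int main()
{
    std::printf("block: %d elements per thread\n", ExampleTransfer<ExampleBlock>::WorkPerThread(4096));
    std::printf("wave:  %d elements per thread\n", ExampleTransfer<ExampleWave>::WorkPerThread(4096));
}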
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
+#pragma once
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r3.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r3.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -447,27 +445,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 // A matrix blockwise copy
 auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 AElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<K0PerBlock, MPerBlock, K1>,
 ABlockTransferThreadClusterLengths_K0_M_K1,
 ABlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(a_grid_desc_k0_m_k1),
 decltype(a_block_desc_k0_m_k1),
 ABlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 ABlockTransferSrcVectorDim,
 2,
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_K1,
 1,
 1,
 AThreadTransferSrcResetCoordinateAfterRun,
 true>(
 a_grid_desc_k0_m_k1,
 make_multi_index(0, m_block_data_idx_on_grid, 0),
 a_element_op,
@@ -477,27 +475,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 // B matrix blockwise copy
 auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 BElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
 Sequence<K0PerBlock, NPerBlock, K1>,
 BBlockTransferThreadClusterLengths_K0_N_K1,
 BBlockTransferThreadClusterArrangeOrder,
 FloatAB,
 FloatAB,
 decltype(b_grid_desc_k0_n_k1),
 decltype(b_block_desc_k0_n_k1),
 BBlockTransferSrcAccessOrder,
 Sequence<1, 0, 2>,
 BBlockTransferSrcVectorDim,
 2,
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_K1,
 1,
 1,
 BThreadTransferSrcResetCoordinateAfterRun,
 true>(
 b_grid_desc_k0_n_k1,
 make_multi_index(0, n_block_data_idx_on_grid, 0),
 b_element_op,
@@ -514,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 // sanity check
 auto blockwise_gemm =
-BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
@@ -684,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 n_thread_data_on_block_idx[I2]),
 ck::tensor_operation::element_wise::PassThrough{}};
-auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3<
-BlockSize, // index_t BlockSize,
+auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
+ThisThreadBlock, // ThreadGroup
 CElementwiseOperation, // ElementwiseOperation,
 CGlobalMemoryDataOperation, // DstInMemOp,
 Sequence<1,
@@ -826,4 +824,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 };
 } // namespace ck
-#endif
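Editor's note: besides the renames, this file drops its hand-written include guard in favour of #pragma once. Both forms prevent double inclusion; the change amounts to:

// before: classic include guard
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
// ... header contents ...
#endif // CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP

// after: non-standard but universally supported pragma
#pragma once
// ... header contents ...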
@@ -21,9 +21,9 @@ struct TupleElement
 {
 __host__ __device__ constexpr TupleElement() = default;
-template <typename T,
-          typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value,
-                             bool>::type = false>
+template <
+    typename T,
+    typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false>
 __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
 {
 }
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
 __host__ __device__ constexpr Tuple() = default;
 template <typename Y,
-          typename enable_if<sizeof...(Xs) == 1 &&
-              !is_same<remove_cvref_t<Y>, Tuple>::value,
+          typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
           bool>::type = false>
 __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
 {
...
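Editor's note: the tuple.hpp hunks are pure reformatting of an existing SFINAE constraint. The enable_if/is_same clause keeps the perfect-forwarding constructor from hijacking copy construction. Below is a condensed, standalone rewrite of that pattern using the standard library; the library's own enable_if, is_same and remove_cvref_t are assumed to behave like their std counterparts, and ExampleElement is a made-up name.

#include <cstdio>
#include <type_traits>
#include <utility>

template <typename Data>
struct ExampleElement
{
    ExampleElement() = default;

    // Without this constraint, ExampleElement(ExampleElement&) would prefer the template
    // (T deduces to ExampleElement&, an exact match) over the implicit copy constructor.
    template <typename T,
              typename std::enable_if<
                  !std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, ExampleElement>::value,
                  bool>::type = false>
    constexpr ExampleElement(T&& v) : mData(std::forward<T>(v))
    {
    }

    Data mData{};
};

int main()
{
    ExampleElement<int> a{42};
    ExampleElement<int> b{a}; // resolves to the copy constructor, not the forwarding one
    std::printf("%d %d\n", a.mData, b.mData);
}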