Commit b134b7d6 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 090ba885 9f71ff48
@@ -2,6 +2,7 @@
 #define DEVICE_BASE_CPU_HPP
 #include <string>
+#include "stream_config.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -23,7 +24,7 @@ struct BaseInvoker
     BaseInvoker(const BaseInvoker&) = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;
-    virtual float Run(const BaseArgument*, int = 1) = 0;
+    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}, int = 1) = 0;
     virtual ~BaseInvoker() {}
 };
...
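For orientation, a minimal sketch of how a caller might drive the updated invoker interface. The names op, MakeArgument, and MakeInvoker are placeholders for a concrete device operation, not names taken from this diff:

// Hypothetical caller of the updated Run signature.
auto argument = op.MakeArgument(/* ... */); // hypothetical argument factory
auto invoker  = op.MakeInvoker();           // hypothetical invoker factory
// StreamConfig selects the stream/timing behavior; nrepeat still controls
// how many times the kernel is replayed for averaging.
float ave_time = invoker.Run(&argument, StreamConfig{}, /*nrepeat=*/10);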
@@ -690,7 +690,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 {
     using Argument = DeviceOp::Argument;
-    float Run(const Argument& arg, int nrepeat = 1)
+    float Run(const Argument& arg,
+              const StreamConfig& stream_config = StreamConfig{},
+              int nrepeat = 1)
     {
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
         {
@@ -743,9 +745,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         return ave_time;
     }
-    float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+    float Run(const BaseArgument* p_arg,
+              const StreamConfig& stream_config = StreamConfig{},
+              int nrepeat = 1) override
     {
-        return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
     }
 };
...
-#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
-#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
+#pragma once
 #include "common_header.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"
 #include "tensor_adaptor.hpp"
+#include "thread_group.hpp"
 namespace ck {
+enum struct LoopScheduler
+{
+    Default,
+    Interwave,
+};
+constexpr LoopScheduler make_default_loop_scheduler()
+{
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    return LoopScheduler::Interwave;
+#else
+    return LoopScheduler::Default;
+#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+}
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
@@ -25,7 +39,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
-    static constexpr index_t WaveSize = 64;
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    static constexpr index_t WaveSize = get_warp_size();
     static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
     static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
@@ -55,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     __device__ static auto GetWaveIdx()
     {
-        const index_t thread_id = get_thread_local_1d_id();
+        const index_t thread_id = ThisThreadBlock::GetThreadId();
         constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
@@ -122,8 +138,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
-        static_assert(BlockSize == MWaves * NWaves * WaveSize,
-                      "BlockSize != MWaves * NWaves * WaveSize\n");
+        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
         static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                       "wrong!");
@@ -301,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         });
     }
-    private:
+    protected:
     // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
@@ -338,5 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 };
+// Note: To enable the inter-wave loop scheduler, the macro
+// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 must be set explicitly, as a few required
+// intrinsics are not yet available in the latest ROCm release. On unsupported compilers,
+// the inter-wave loop scheduler falls back to the default loop scheduler, which is selected
+// by CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
+template <index_t BlockSize,
+          typename FloatAB,
+          typename FloatAcc,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
+          index_t KPack,
+          index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
+struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
+    : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                 FloatAB,
+                                                                 FloatAcc,
+                                                                 AK0MK1BlockDesc,
+                                                                 BK0NK1BlockDesc,
+                                                                 MPerXDL,
+                                                                 NPerXDL,
+                                                                 MRepeat,
+                                                                 NRepeat,
+                                                                 KPack>
+{
+    using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                     FloatAB,
+                                                                     FloatAcc,
+                                                                     AK0MK1BlockDesc,
+                                                                     BK0NK1BlockDesc,
+                                                                     MPerXDL,
+                                                                     NPerXDL,
+                                                                     MRepeat,
+                                                                     NRepeat,
+                                                                     KPack>;
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    using Base::a_block_desc_m0_m1_m2_k;
+    using Base::A_K1;
+    using Base::b_block_desc_n0_n1_n2_k;
+    using Base::B_K1;
+    using Base::c_thread_buf_;
+    using Base::c_thread_desc_;
+    using Base::CalculateAThreadOriginDataIndex;
+    using Base::CalculateBThreadOriginDataIndex;
+    using Base::I0;
+    using Base::I1;
+    using Base::KPerThread;
+    using Base::xdlops_gemm;
+    static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
+    // 2-wave optimized blockwise gemm
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+        static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, k),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, k),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_thread_buf);
+            });
+            __builtin_amdgcn_sched_barrier();
+            // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
+            // except the first, as this shortens the non-MAC cluster a bit with no observable
+            // negative impact. The desired effect is that waves in a workgroup execute their
+            // MACs in sync, which keeps out-of-sync waves from hijacking MAC resources from
+            // other workgroups and from reducing the chance of latency hiding by waiting for
+            // the rest of the workgroup at the eventual sync point.
+            if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
+            {
+                asm volatile("s_barrier" ::);
+                __builtin_amdgcn_sched_barrier();
+            }
+            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<FloatAB, KPack> a_thread_vec;
+                        vector_type<FloatAB, KPack> b_thread_vec;
+                        static_for<0, KPack, 1>{}([&](auto i) {
+                            a_thread_vec.template AsType<FloatAB>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, 0, 0, k_ + i))>{}];
+                            b_thread_vec.template AsType<FloatAB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(n0, 0, 0, k_ + i))>{}];
+                        });
+                        using mfma_input_type =
+                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                        // The block_sync_lds() here performs double duty:
+                        // A) safeguard against data hazards, because the barrier from
+                        //    blockwise_gemm is moved here
+                        // B) reduce VMEM FIFO congestion by applying small delays to
+                        //    different wavefronts
+                        // It is performed near the end of the MAC cluster to minimize the
+                        // lgkmcnt penalty.
+                        if constexpr(k.value == KPerThread - KPerInnerLoop &&
+                                     k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
+                                     n0.value == NRepeat - 1)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            block_sync_lds();
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                        // TODO: insert setprio more precisely, since a single call may
+                        // contain more than one MFMA instruction
+                        xdlops_gemm.template Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            __builtin_amdgcn_s_setprio(1);
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                    });
+                });
+            });
+            __builtin_amdgcn_sched_barrier();
+            __builtin_amdgcn_s_setprio(0);
+            __builtin_amdgcn_sched_barrier();
+        });
+    }
+protected:
+    // A[M0, M1, M2, KPerInnerLoop]
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+    // B[N0, N1, N2, KPerInnerLoop]
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         decltype(a_block_desc_m0_m1_m2_k),
+                                                         decltype(a_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         A_K1,
+                                                         A_K1>;
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         decltype(b_block_desc_n0_n1_n2_k),
+                                                         decltype(b_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         B_K1,
+                                                         B_K1>;
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+};
+template <index_t BlockSize,
+          typename FloatAB,
+          typename FloatAcc,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
+          index_t KPack,
+          LoopScheduler LoopSched>
+constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                   FloatAB,
+                                                                   FloatAcc,
+                                                                   AK0MK1BlockDesc,
+                                                                   BK0NK1BlockDesc,
+                                                                   MPerXDL,
+                                                                   NPerXDL,
+                                                                   MRepeat,
+                                                                   NRepeat,
+                                                                   KPack>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                            FloatAB,
+                                                                            FloatAcc,
+                                                                            AK0MK1BlockDesc,
+                                                                            BK0NK1BlockDesc,
+                                                                            MPerXDL,
+                                                                            NPerXDL,
+                                                                            MRepeat,
+                                                                            NRepeat,
+                                                                            KPack>{};
+    }
+};
 } // namespace ck
-#endif
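For reference, a minimal sketch of how the new selector might be instantiated by a gridwise pipeline. All template arguments are assumed to be in scope, and the alias name below is illustrative rather than taken from this commit. Defining CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 at compile time (e.g. passing -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 to the compiler) makes make_default_loop_scheduler() return LoopScheduler::Interwave, so the selector then yields the inter-wave implementation:

// Illustrative sketch: resolve the blockwise GEMM type at compile time.
// The template arguments are assumed to be supplied by the enclosing kernel.
using BlockwiseGemm =
    decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
             FloatAB,
             FloatAcc,
             AK0MK1BlockDesc,
             BK0NK1BlockDesc,
             MPerXDL,
             NPerXDL,
             MRepeat,
             NRepeat,
             KPack,
             make_default_loop_scheduler()>());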
@@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1
           src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
...
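The remove_cvref_t spelling adopted above collapses the older remove_reference_t<remove_cv_t<...>> chain. A sketch of the alias, assuming ck defines it the same way C++20 defines std::remove_cvref_t:

// Assumed definition, mirroring std::remove_cvref_t:
template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
// e.g. remove_cvref_t<const SrcDesc&> is SrcDesc, so the static_assert
// always inspects the underlying descriptor type.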
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+#pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,7 +11,7 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename SrcElementwiseOperation,
           typename DstElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
@@ -35,7 +33,7 @@ template <index_t BlockSize,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun,
           index_t NumThreadScratch = 1>
-struct BlockwiseTensorSliceTransfer_v4r1
+struct ThreadGroupTensorSliceTransfer_v4r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
     using Index = MultiIndex<nDim>;
-    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
         const SrcDesc& src_desc,
         const Index& src_block_slice_origin,
         const SrcElementwiseOperation& src_element_op,
@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                               dst_element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
             is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));
             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
         }
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                            DstBuffer& dst_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
         }
@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
         }
@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
 };
 } // namespace ck
-#endif
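These transfer classes now take a ThreadGroup type in place of a bare BlockSize, and use it only through two static members. A minimal sketch of a conforming thread group, assuming get_thread_local_1d_id() is the underlying thread-id helper (the ThisThreadBlock type from thread_group.hpp presumably looks similar):

// Sketch of the implicit ThreadGroup interface used by these transfer classes.
template <index_t BlockSize>
struct ExampleThreadGroup
{
    __device__ static constexpr index_t GetNumOfThread() { return BlockSize; }
    __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};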
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
+#pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename SrcData,
@@ -28,19 +26,19 @@ template <index_t BlockSize,
           index_t ScalarPerVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r1
+struct ThreadGroupTensorSliceTransfer_v6r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
     using Index = MultiIndex<nDim>;
-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                              const Index& src_block_slice_origin,
                                                              const DstDesc& dst_desc,
                                                              const Index& dst_block_slice_origin,
                                                              const ElementwiseOperation& element_op)
         : threadwise_transfer_(src_desc,
                                make_zero_multi_index<nDim>(),
                                dst_desc,
@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");
         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));
             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
                          const DstDesc& dst_desc,
                          DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
        }
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1
 };
 } // namespace ck
-#endif
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
+#pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. It does not keep reference to tensor descriptor
 // 3. Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename Src0Data,
@@ -31,21 +29,21 @@ template <index_t BlockSize,
           bool ThreadTransferSrc0ResetCoordinateAfterRun,
           bool ThreadTransferSrc1ResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r2
+struct ThreadGroupTensorSliceTransfer_v6r2
 {
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
     using Index = MultiIndex<nDim>;
-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
                                                              const Index& src0_block_slice_origin,
                                                              const Src1Desc& src1_desc,
                                                              const Index& src1_block_slice_origin,
                                                              const DstDesc& dst_desc,
                                                              const Index& dst_block_slice_origin,
                                                              const ElementwiseOperation& element_op)
         : threadwise_transfer_(src0_desc,
                                make_zero_multi_index<nDim>(),
                                src1_desc,
@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");
         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));
             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
                          const DstDesc& dst_desc,
                          DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
        }
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
 };
 } // namespace ck
-#endif
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
+#pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename Src0Data,
@@ -34,23 +32,23 @@ template <index_t BlockSize,
           bool ThreadTransferSrc1ResetCoordinateAfterRun,
           bool ThreadTransferSrc2ResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r3
+struct ThreadGroupTensorSliceTransfer_v6r3
 {
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
     using Index = MultiIndex<nDim>;
-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
                                                              const Index& src0_block_slice_origin,
                                                              const Src1Desc& src1_desc,
                                                              const Index& src1_block_slice_origin,
                                                              const Src2Desc& src2_desc,
                                                              const Index& src2_block_slice_origin,
                                                              const DstDesc& dst_desc,
                                                              const Index& dst_block_slice_origin,
                                                              const ElementwiseOperation& element_op)
         : threadwise_transfer_(src0_desc,
                                make_zero_multi_index<nDim>(),
                                src1_desc,
@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");
         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(get_thread_local_1d_id()));
@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
                          const DstDesc& dst_desc,
                          DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.Run(
                 src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
 };
 } // namespace ck
-#endif
-#ifndef DEVICE_BASE_HPP
-#define DEVICE_BASE_HPP
+#pragma once
 #include <string>
+#include "stream_config.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -22,7 +23,10 @@ struct BaseInvoker
     BaseInvoker(const BaseInvoker&) = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;
-    virtual float Run(const BaseArgument*, int = 1) = 0;
+    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
+    {
+        return float{0};
+    }
     virtual ~BaseInvoker() {}
 };
@@ -33,8 +37,8 @@ struct BaseOperator
     BaseOperator(const BaseOperator&) = default;
     BaseOperator& operator=(const BaseOperator&) = default;
-    virtual bool IsSupportedArgument(const BaseArgument*) = 0;
-    virtual std::string GetTypeString() const = 0;
+    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
+    virtual std::string GetTypeString() const { return ""; }
     virtual ~BaseOperator() {}
 };
@@ -42,4 +46,3 @@ struct BaseOperator
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
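Since IsSupportedArgument and GetTypeString are now defaulted instead of pure virtual, a type-erased driver can probe any operator generically. A minimal sketch, with op, invoker, and p_arg as placeholder names rather than identifiers from this diff:

// Illustrative only: probe a type-erased operator before dispatching it.
if(op.IsSupportedArgument(p_arg)) // default implementation returns false
{
    std::cout << op.GetTypeString() << std::endl; // default returns ""
    float ave_time = invoker.Run(p_arg, StreamConfig{});
}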
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename D0ReduceOperation,
-          typename D1ReduceOperation,
+          typename D1ElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -44,8 +43,7 @@ __global__ void
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
-             const D0ReduceOperation d0_reduce_op,
-             const D1ReduceOperation d1_reduce_op,
+             const D1ElementwiseOperation d1_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -82,8 +80,7 @@ __global__ void
                                              a_element_op,
                                              b_element_op,
                                              c_element_op,
-                                              d0_reduce_op,
-                                              d1_reduce_op,
+                                              d1_element_op,
                                              a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -99,8 +96,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-    ignore = d0_reduce_op;
-    ignore = d1_reduce_op;
+    ignore = d1_element_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -110,6 +106,9 @@ __global__ void
 #endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
 }
+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because
+// the non-c-shuffle version currently has compiler issues with register spills, which in turn
+// cause validation failures.
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -125,6 +124,7 @@ template <typename ALayout,
           typename CElementwiseOperation,
           typename D0ReduceOperation,
           typename D1ReduceOperation,
+          typename D1ElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -157,12 +157,12 @@ template <typename ALayout,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
+          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                                       BElementwiseOperation,
                                                                       CElementwiseOperation,
-                                                                      D0ReduceOperation,
-                                                                      D1ReduceOperation>
+                                                                      D1ElementwiseOperation>
 {
     using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -564,6 +564,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         CElementwiseOperation,
         D0ReduceOperation,
         D1ReduceOperation,
+        D1ElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_AK0_M_AK1,
@@ -603,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         CReduceThreadClusterLengths_MPerBlock_NPerBlock,
         CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
+        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+        LoopSched>;
     using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
@@ -624,8 +626,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 D0ReduceOperation d0_reduce_op,
-                 D1ReduceOperation d1_reduce_op,
+                 D1ElementwiseOperation d1_element_op,
                  index_t BatchCount)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
@@ -639,17 +640,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
               d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               d_grid_desc_mblock_mperblock_{},
-              compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
-                                         b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize()},
+              compute_base_ptr_of_batch_{
+                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op},
-              d0_reduce_op_{d0_reduce_op},
-              d1_reduce_op_{d1_reduce_op}
+              d1_element_op_{d1_element_op}
         {
             if(GridwiseGemm::CheckValidity(
                    a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -684,8 +685,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        D0ReduceOperation d0_reduce_op_;
-        D1ReduceOperation d1_reduce_op_;
+        D1ElementwiseOperation d1_element_op_;
     };
    // Invoker
@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int /* nrepeat */ = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -726,11 +726,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
            const index_t grid_size =
                GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-           const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-           if(has_main_k0_block_loop)
+           const auto K =
+               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+           float elapsed_time = 0.0f;
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
                    GridwiseGemm,
@@ -740,8 +740,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
-                   D0ReduceOperation,
-                   D1ReduceOperation,
+                   D1ElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -750,27 +749,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                    remove_reference_t<Block2CTileMap>,
                    true>;

-               launch_kernel(kernel,
-                             dim3(grid_size),
-                             dim3(BlockSize),
-                             0,
-                             arg.p_a_grid_,
-                             arg.p_b_grid_,
-                             arg.p_c_grid_,
-                             arg.p_d0_grid_,
-                             arg.p_d1_grid_,
-                             arg.BatchCount_,
-                             arg.a_element_op_,
-                             arg.b_element_op_,
-                             arg.c_element_op_,
-                             arg.d0_reduce_op_,
-                             arg.d1_reduce_op_,
-                             arg.a_grid_desc_ak0_m_ak1_,
-                             arg.b_grid_desc_bk0_n_bk1_,
-                             arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                             arg.d_grid_desc_mblock_mperblock_,
-                             arg.compute_base_ptr_of_batch_,
-                             arg.block_2_ctile_map_);
+               elapsed_time =
+                   launch_and_time_kernel(stream_config,
+                                          kernel,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          arg.p_a_grid_,
+                                          arg.p_b_grid_,
+                                          arg.p_c_grid_,
+                                          arg.p_d0_grid_,
+                                          arg.p_d1_grid_,
+                                          arg.BatchCount_,
+                                          arg.a_element_op_,
+                                          arg.b_element_op_,
+                                          arg.c_element_op_,
+                                          arg.d1_element_op_,
+                                          arg.a_grid_desc_ak0_m_ak1_,
+                                          arg.b_grid_desc_bk0_n_bk1_,
+                                          arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                          arg.d_grid_desc_mblock_mperblock_,
+                                          arg.compute_base_ptr_of_batch_,
+                                          arg.block_2_ctile_map_);
            }
            else
            {
@@ -782,8 +782,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
-                   D0ReduceOperation,
-                   D1ReduceOperation,
+                   D1ElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -792,36 +791,38 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                    remove_reference_t<Block2CTileMap>,
                    false>;

-               launch_kernel(kernel,
-                             dim3(grid_size),
-                             dim3(BlockSize),
-                             0,
-                             arg.p_a_grid_,
-                             arg.p_b_grid_,
-                             arg.p_c_grid_,
-                             arg.p_d0_grid_,
-                             arg.p_d1_grid_,
-                             arg.BatchCount_,
-                             arg.a_element_op_,
-                             arg.b_element_op_,
-                             arg.c_element_op_,
-                             arg.d0_reduce_op_,
-                             arg.d1_reduce_op_,
-                             arg.a_grid_desc_ak0_m_ak1_,
-                             arg.b_grid_desc_bk0_n_bk1_,
-                             arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                             arg.d_grid_desc_mblock_mperblock_,
-                             arg.compute_base_ptr_of_batch_,
-                             arg.block_2_ctile_map_);
+               elapsed_time =
+                   launch_and_time_kernel(stream_config,
+                                          kernel,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          arg.p_a_grid_,
+                                          arg.p_b_grid_,
+                                          arg.p_c_grid_,
+                                          arg.p_d0_grid_,
+                                          arg.p_d1_grid_,
+                                          arg.BatchCount_,
+                                          arg.a_element_op_,
+                                          arg.b_element_op_,
+                                          arg.c_element_op_,
+                                          arg.d1_element_op_,
+                                          arg.a_grid_desc_ak0_m_ak1_,
+                                          arg.b_grid_desc_bk0_n_bk1_,
+                                          arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                          arg.d_grid_desc_mblock_mperblock_,
+                                          arg.compute_base_ptr_of_batch_,
+                                          arg.block_2_ctile_map_);
            }

-           return 0;
+           return elapsed_time;
        }

        // polymorphic
-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
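Every `Run` overload in this commit now takes a `StreamConfig` instead of an `nrepeat` count. For orientation, a minimal sketch of what such a config object carries; the field names follow CK's conventions but are an assumption, not quoted from this commit:

```cpp
// Hedged sketch of a stream configuration object (assumed field names):
// it bundles the HIP stream to launch on with a flag telling the launch
// helper whether to wrap the launch in timing events.
#include <hip/hip_runtime.h>

struct StreamConfigSketch
{
    hipStream_t stream_id_ = nullptr; // stream the kernel is enqueued on
    bool time_kernel_      = false;   // if true, measure elapsed time with hipEvents
};
```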
@@ -865,8 +866,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                                 AElementwiseOperation a_element_op,
                                 BElementwiseOperation b_element_op,
                                 CElementwiseOperation c_element_op,
-                                D0ReduceOperation d0_reduce_op,
-                                D1ReduceOperation d1_reduce_op,
+                                D1ElementwiseOperation d1_element_op,
                                 index_t BatchCount)
    {
        return Argument{p_a,
@@ -883,8 +883,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                        a_element_op,
                        b_element_op,
                        c_element_op,
-                       d0_reduce_op,
-                       d1_reduce_op,
+                       d1_element_op,
                        BatchCount};
    }
@@ -905,8 +904,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
-                       D0ReduceOperation d0_reduce_op,
-                       D1ReduceOperation d1_reduce_op,
+                       D1ElementwiseOperation d1_element_op,
                        index_t BatchCount) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -923,8 +921,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                                          a_element_op,
                                          b_element_op,
                                          c_element_op,
-                                         d0_reduce_op,
-                                         d1_reduce_op,
+                                         d1_element_op,
                                          BatchCount);
    }
...
@@ -107,7 +107,7 @@ __global__ void
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
-   ignore = compute_base_ptr_of_batch_;
+   ignore = compute_ptr_offset_of_batch;
    ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
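The hunk above also fixes the stray trailing underscore so the name matches the kernel parameter. The `ignore = ...;` lines are CK's idiom for silencing unused-parameter warnings when the kernel body is compiled out on architectures without xdlops. A self-contained sketch of the idiom; the exact definition of CK's `ignore` helper is assumed here:

```cpp
// Sketch of the "ignore" idiom: a sink object with a templated assignment
// operator, so anything assigned to it is discarded (like std::ignore).
namespace detail {
struct IgnoreSink
{
    template <typename T>
    constexpr const IgnoreSink& operator=(T&&) const { return *this; }
};
} // namespace detail
inline constexpr detail::IgnoreSink ignore{};

void example(int unused_arg)
{
    ignore = unused_arg; // suppresses -Wunused-parameter without a cast
}
```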
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
              DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
              c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-             compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                          b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                          c_grid_desc_m_n_.GetElementSpaceSize()},
+             compute_ptr_offset_of_batch_{
+                 type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
+                 type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
+                 type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
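Wrapping the element-space sizes in `type_convert<index_t>` makes the narrowing from the descriptor's size type explicit instead of implicit. A hedged sketch of the same pattern; CK's real `type_convert` is a plain cast-style conversion, so the range check below is illustrative only:

```cpp
#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Illustrative convert-with-assert; treat the bounds check as an assumption,
// not CK's actual behavior.
template <typename Dst, typename Src>
Dst checked_convert(Src value)
{
    Dst result = static_cast<Dst>(value);
    assert(static_cast<Src>(result) == value && "narrowing lost information");
    return result;
}

int main()
{
    const long element_space_size = 1 << 20;
    const index_t offset = checked_convert<index_t>(element_space_size);
    return offset == (1 << 20) ? 0 : 1;
}
```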
@@ -427,7 +428,7 @@ struct DeviceBatchedGemmXdl
    {
        using Argument = DeviceBatchedGemmXdl::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            {
                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -455,13 +456,12 @@ struct DeviceBatchedGemmXdl
            const index_t grid_size =
                GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
                    GridwiseGemm,
@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
                    remove_reference_t<Block2CTileMap>,
                    true>;

-               ave_time = launch_and_time_kernel(kernel,
-                                                 nrepeat,
+               ave_time = launch_and_time_kernel(stream_config,
+                                                 kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
                    remove_reference_t<Block2CTileMap>,
                    false>;

-               ave_time = launch_and_time_kernel(kernel,
-                                                 nrepeat,
+               ave_time = launch_and_time_kernel(stream_config,
+                                                 kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
        }

        // polymorphic
-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
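Every launch site in this commit switches from `launch_and_time_kernel(kernel, nrepeat, ...)` to `launch_and_time_kernel(stream_config, kernel, ...)`: the repeat count leaves the call and timing becomes a property of the config. A minimal sketch of such a wrapper, assuming hipEvent-based timing and reusing the assumed config fields from the earlier sketch; this is an illustration, not the library's actual implementation:

```cpp
#include <cstddef>
#include <hip/hip_runtime.h>

struct StreamConfigSketch // repeated from the earlier sketch (assumed fields)
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
};

// Sketch: launch `kernel` on the configured stream and, if requested,
// time it with a pair of hipEvents around the launch.
template <typename Kernel, typename... Args>
float launch_and_time_kernel_sketch(const StreamConfigSketch& config,
                                    Kernel kernel,
                                    dim3 grid,
                                    dim3 block,
                                    std::size_t lds_bytes,
                                    Args... args)
{
    float elapsed_ms = 0.0f;
    if(config.time_kernel_)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);
        hipEventRecord(start, config.stream_id_);
        hipLaunchKernelGGL(kernel, grid, block, lds_bytes, config.stream_id_, args...);
        hipEventRecord(stop, config.stream_id_);
        hipEventSynchronize(stop);
        hipEventElapsedTime(&elapsed_ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
    }
    else
    {
        hipLaunchKernelGGL(kernel, grid, block, lds_bytes, config.stream_id_, args...);
    }
    return elapsed_ms;
}
```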
...
@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
        }

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            ShowInfo(arg);

            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                            arg.b_grid_desc_kbatch_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
@@ -437,49 +438,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {
-               if(nrepeat > 0)
-               {
-                   ave_time =
-                       launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_a_grid_,
-                                              arg.p_b_grid_,
-                                              arg.p_c_grid_,
-                                              arg.a_grid_desc_kbatch_k0_m_k1_,
-                                              arg.b_grid_desc_kbatch_k0_n_k1_,
-                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                              arg.a_element_op_,
-                                              arg.b_element_op_,
-                                              arg.c_element_op_,
-                                              arg.block_2_ctile_map_);
-               }
-
-               if(kbatch > 1 || nrepeat <= 0)
-               {
-                   hipGetErrorString(hipMemset(
-                       arg.p_c_grid_,
-                       0,
-                       arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-                           sizeof(CDataType)));
-
-                   launch_kernel(kernel,
-                                 dim3(grid_size),
-                                 dim3(BlockSize),
-                                 0,
-                                 arg.p_a_grid_,
-                                 arg.p_b_grid_,
-                                 arg.p_c_grid_,
-                                 arg.a_grid_desc_kbatch_k0_m_k1_,
-                                 arg.b_grid_desc_kbatch_k0_n_k1_,
-                                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                 arg.a_element_op_,
-                                 arg.b_element_op_,
-                                 arg.c_element_op_,
-                                 arg.block_2_ctile_map_);
-               }
+               hipGetErrorString(hipMemset(
+                   arg.p_c_grid_,
+                   0,
+                   arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+                       sizeof(CDataType)));
+
+               launch_and_time_kernel(stream_config,
+                                      kernel,
+                                      dim3(grid_size),
+                                      dim3(BlockSize),
+                                      0,
+                                      arg.p_a_grid_,
+                                      arg.p_b_grid_,
+                                      arg.p_c_grid_,
+                                      arg.a_grid_desc_kbatch_k0_m_k1_,
+                                      arg.b_grid_desc_kbatch_k0_n_k1_,
+                                      arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                      arg.a_element_op_,
+                                      arg.b_element_op_,
+                                      arg.c_element_op_,
+                                      arg.block_2_ctile_map_);
            };

            if(has_main_k0_block_loop)
@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
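With the `nrepeat` branches gone, this backward-weight invoker now zeroes the output before every launch rather than only on the non-timed path. That matters because with `kbatch > 1` the kernel accumulates partial results from the K-splits into `p_c_grid_`, so stale contents would corrupt the sum. A toy host-side sketch of why split-K needs a zero-initialized output, with hypothetical names:

```cpp
#include <vector>

// Toy illustration: split-K GEMM accumulates partial dot products from each
// K-chunk directly into C, so C must start at zero before the first chunk.
void split_k_accumulate(const std::vector<float>& a, // M*K, row-major
                        const std::vector<float>& b, // K*N, row-major
                        std::vector<float>& c,       // M*N, must be pre-zeroed
                        int M, int N, int K, int kbatch)
{
    const int k_per_batch = K / kbatch; // assume K divisible, for simplicity
    for(int kb = 0; kb < kbatch; ++kb)  // each kb plays the role of one launch
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    c[m * N + n] += a[m * K + k] * b[k * N + n]; // accumulate
}
```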
...
@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            float ave_time = 0;

            for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                const index_t grid_size =
                    GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

-               const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
-
-               const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-               if(has_main_k0_block_loop)
+               const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
+                              arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
+
+               if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
                {
                    const auto kernel = kernel_gemm_xdlops_v2r3<
                        GridwiseGemm,
@@ -603,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                        true>;

                    ave_time += launch_and_time_kernel(
+                       stream_config,
                        kernel,
-                       nrepeat,
                        dim3(grid_size),
                        dim3(BlockSize),
                        0,
@@ -636,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                        false>;

                    ave_time += launch_and_time_kernel(
+                       stream_config,
                        kernel,
-                       nrepeat,
                        dim3(grid_size),
                        dim3(BlockSize),
                        0,
@@ -656,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
@@ -698,7 +698,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
        }

        // Gridwise GEMM size
-       for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+       for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                            arg.b_grid_desc_k0_n_k1_container_[i],
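The `int` to `std::size_t` loop-index change above removes a signed/unsigned comparison against `std::vector::size()`. The same fix in isolation, with a hypothetical container name:

```cpp
#include <cstddef>
#include <vector>

int count_valid(const std::vector<bool>& valid_flags)
{
    int n = 0;
    // std::size_t matches the unsigned type returned by size(), so the
    // comparison `i < valid_flags.size()` no longer triggers -Wsign-compare.
    for(std::size_t i = 0; i < valid_flags.size(); ++i)
        if(valid_flags[i]) ++n;
    return n;
}
```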
...
@@ -642,7 +642,7 @@ struct
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -698,13 +698,12 @@ struct
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_xdlops_v3r3<
                    GridwiseGemm,
@@ -728,8 +727,8 @@ struct
                    true>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -772,8 +771,8 @@ struct
                    false>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -796,9 +795,10 @@ struct
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
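Across these invokers the predicate changes from `CalculateHasMainK0BlockLoop(K0)` to `CalculateHasMainKBlockLoop(K)`, with the full `K = K0 * K1` recovered from the descriptor's `I0` and `I2` lengths. A heavily hedged sketch of what such a predicate typically computes; the real GridwiseGemm logic and tile sizes may differ:

```cpp
using index_t = int;

// Illustrative only: decide whether the GEMM main loop runs more than one
// K-per-block iteration, so each launch site above can pick the `true` or
// `false` kernel specialization at compile time.
constexpr index_t KPerBlock_assumed = 32; // placeholder tile depth

constexpr bool calculate_has_main_k_block_loop_sketch(index_t K)
{
    const index_t num_loop = K / KPerBlock_assumed;
    return num_loop > 1;
}

static_assert(calculate_has_main_k_block_loop_sketch(256));
static_assert(!calculate_has_main_k_block_loop_sketch(32));
```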
...
-#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
-#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
 #include "device.hpp"
@@ -607,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_xdlops_v3r2<
                    GridwiseGemm,
@@ -687,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
                    true>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -726,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
                    false>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -748,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
@@ -919,4 +917,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
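This file trades its macro include guard for `#pragma once` (note the matching `#endif` removed at the bottom). Both prevent double inclusion; the pragma is shorter and immune to guard-name collisions, at the cost of being a non-standard (though universally supported) extension. Side by side, with a hypothetical guard name:

```cpp
// Alternative 1: classic include guard; the macro name must stay unique
// across the whole project.
#ifndef MY_HEADER_HPP
#define MY_HEADER_HPP
void declaration_example();
#endif // MY_HEADER_HPP

// Alternative 2: the pragma form this commit adopts; no name to manage and
// no trailing #endif at the end of the file (shown commented out so this
// snippet stays a single valid header):
// #pragma once
// void declaration_example();
```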
...
@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_xdlops_v3r1<
                    GridwiseGemm,
@@ -664,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                    true>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -698,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                    false>;

                ave_time = launch_and_time_kernel(
+                   stream_config,
                    kernel,
-                   nrepeat,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
@@ -718,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
...
@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    {
        using Argument = DeviceOp::Argument;

-       float Run(const Argument& arg, int nrepeat = 1)
+       float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_xdlops_v2r3<
                    GridwiseGemm,
@@ -499,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    true>;

-               ave_time = launch_and_time_kernel(kernel,
-                                                 nrepeat,
+               ave_time = launch_and_time_kernel(stream_config,
+                                                 kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
@@ -530,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    false>;

-               ave_time = launch_and_time_kernel(kernel,
-                                                 nrepeat,
+               ave_time = launch_and_time_kernel(stream_config,
+                                                 kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
@@ -550,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            return ave_time;
        }

-       float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+       float Run(const BaseArgument* p_arg,
+                 const StreamConfig& stream_config = StreamConfig{}) override
        {
-           return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
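Taken together, the post-change calling convention for any of these device ops looks roughly like the sketch below. The op type, buffer names, and argument list are placeholders; only the MakeArgumentPointer / MakeInvokerPointer / `Run(arg, StreamConfig)` shape is what this diff establishes:

```cpp
// Hypothetical caller; `DeviceOpT` stands for any device op in this commit,
// and the three pointers are pre-allocated device buffers.
template <typename DeviceOpT>
float run_timed(DeviceOpT& op, const void* p_a, const void* p_b, void* p_c)
{
    auto arg_ptr     = op.MakeArgumentPointer(p_a, p_b, p_c /*, sizes, strides,
                                              element ops, ... as each op requires */);
    auto invoker_ptr = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(arg_ptr.get()))
        return -1.0f; // this instance cannot run the given problem shape

    // Run now takes a StreamConfig instead of an nrepeat count; a
    // default-constructed StreamConfig{} is assumed to mean "default
    // stream, no timing", matching the defaulted parameter above.
    return invoker_ptr->Run(arg_ptr.get(), StreamConfig{});
}
```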
...