"vscode:/vscode.git/clone" did not exist on "ef18b0eda28b37082d158fade59a24b29f6a986c"
Commit b134b7d6 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 090ba885 9f71ff48
......@@ -2,6 +2,7 @@
#define DEVICE_BASE_CPU_HPP
#include <string>
#include "stream_config.hpp"
namespace ck {
namespace tensor_operation {
......@@ -23,7 +24,7 @@ struct BaseInvoker
BaseInvoker(const BaseInvoker&) = default;
BaseInvoker& operator=(const BaseInvoker&) = default;
virtual float Run(const BaseArgument*, int = 1) = 0;
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}, int = 1) = 0;
virtual ~BaseInvoker() {}
};
......
......@@ -690,7 +690,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
......@@ -743,9 +745,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
......
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#pragma once
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
#include "thread_group.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
......@@ -25,7 +39,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
......@@ -55,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
......@@ -122,8 +138,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
......@@ -301,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
});
}
private:
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
......@@ -338,5 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using Base::a_block_desc_m0_m1_m2_k;
using Base::A_K1;
using Base::b_block_desc_n0_n1_n2_k;
using Base::B_K1;
using Base::c_thread_buf_;
using Base::c_thread_desc_;
using Base::CalculateAThreadOriginDataIndex;
using Base::CalculateBThreadOriginDataIndex;
using Base::I0;
using Base::I1;
using Base::KPerThread;
using Base::xdlops_gemm;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
// 2-wave optimized blockwise gemm
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, k),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, k),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
__builtin_amdgcn_sched_barrier();
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except
// the first, as we can shorten non-MAC cluster a bit and there's no observable negative
// impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
// some out-of-sync waves hijacking MAC resource from other workgroups and reducing the
// chance of latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{
asm volatile("s_barrier" ::);
__builtin_amdgcn_sched_barrier();
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from blockwise_gemm is
// moved here B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts It is performed near the end of MAC cluster to
// minimize lgkmcnt penalty
if constexpr(k.value == KPerThread - KPerInnerLoop &&
k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier();
block_sync_lds();
__builtin_amdgcn_sched_barrier();
}
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier();
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier();
}
});
});
});
__builtin_amdgcn_sched_barrier();
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier();
});
}
protected:
// A[M0, M1, M2, KPerInnerLoop]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
// B[N0, N1, N2, KPerInnerLoop]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
};
} // namespace ck
#endif
......@@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
......
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -13,7 +11,7 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
......@@ -35,7 +33,7 @@ template <index_t BlockSize,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
struct ThreadGroupTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
......@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
......@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
dst_element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
......@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
......@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
......@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
......@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
......@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
......@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
......@@ -28,19 +26,19 @@ template <index_t BlockSize,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
struct ThreadGroupTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
......@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
......@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
}
......@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
......@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
......@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
template <index_t BlockSize,
template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
......@@ -31,21 +29,21 @@ template <index_t BlockSize,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
struct ThreadGroupTensorSliceTransfer_v6r2
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
......@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
......@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
}
......@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
......@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
......@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
......@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
template <typename ThreadGroup,
typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
......@@ -34,23 +32,23 @@ template <index_t BlockSize,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
struct ThreadGroupTensorSliceTransfer_v6r3
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
......@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
......@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
......@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
......@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
......@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
}
......@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
......@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
};
} // namespace ck
#endif
#ifndef DEVICE_BASE_HPP
#define DEVICE_BASE_HPP
#pragma once
#include <string>
#include "stream_config.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -22,7 +23,10 @@ struct BaseInvoker
BaseInvoker(const BaseInvoker&) = default;
BaseInvoker& operator=(const BaseInvoker&) = default;
virtual float Run(const BaseArgument*, int = 1) = 0;
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
{
return float{0};
}
virtual ~BaseInvoker() {}
};
......@@ -33,8 +37,8 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default;
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
virtual std::string GetTypeString() const = 0;
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
virtual ~BaseOperator() {}
};
......@@ -42,4 +46,3 @@ struct BaseOperator
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -44,8 +43,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const D1ElementwiseOperation d1_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -82,8 +80,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -99,8 +96,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d0_reduce_op;
ignore = d1_reduce_op;
ignore = d1_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -110,6 +106,9 @@ __global__ void
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -125,6 +124,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -157,12 +157,12 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>
D1ElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
......@@ -564,6 +564,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
......@@ -603,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
......@@ -624,8 +626,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
......@@ -639,17 +640,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize(),
d_grid_desc_m_.GetElementSpaceSize(),
d_grid_desc_m_.GetElementSpaceSize()},
compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op},
d1_reduce_op_{d1_reduce_op}
d1_element_op_{d1_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
......@@ -684,8 +685,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_;
D1ReduceOperation d1_reduce_op_;
D1ElementwiseOperation d1_element_op_;
};
// Invoker
......@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......@@ -726,11 +726,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
if(has_main_k0_block_loop)
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
GridwiseGemm,
......@@ -740,8 +740,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -750,27 +749,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>,
true>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
else
{
......@@ -782,8 +782,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -792,36 +791,38 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>,
false>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
return 0;
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......@@ -865,8 +866,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
{
return Argument{p_a,
......@@ -883,8 +883,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount};
}
......@@ -905,8 +904,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -923,8 +921,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount);
}
......
......@@ -107,7 +107,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_base_ptr_of_batch_;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize()},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -427,7 +428,7 @@ struct DeviceBatchedGemmXdl
{
using Argument = DeviceBatchedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
......@@ -455,13 +456,12 @@ struct DeviceBatchedGemmXdl
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
......@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
......@@ -437,49 +438,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
......@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
......@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
......@@ -603,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
true>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -636,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
false>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -656,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......@@ -698,7 +698,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
......@@ -642,7 +642,7 @@ struct
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......@@ -698,13 +698,12 @@ struct
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v3r3<
GridwiseGemm,
......@@ -728,8 +727,8 @@ struct
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -772,8 +771,8 @@ struct
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -796,9 +795,10 @@ struct
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
......@@ -607,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v3r2<
GridwiseGemm,
......@@ -687,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -726,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -748,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......@@ -919,4 +917,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v3r1<
GridwiseGemm,
......@@ -664,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -698,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -718,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
......@@ -499,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -530,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -550,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -4,7 +4,7 @@
#include <iostream>
#include <memory>
#include <sstream>
#include "conv_fwd_util.hpp"
#include "conv_util.hpp"
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
......@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
......@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
WeiElementwiseOperation,
OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd,
nrepeat,
float ave_time = launch_and_time_kernel(stream_config,
naive_conv3d_fwd,
dim3(256),
dim3(256),
0,
......@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
......@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
OutElementwiseOperation,
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -1241,7 +1241,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
......@@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
if(has_main_k0_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
......@@ -1317,8 +1316,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
true>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -1350,8 +1349,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
false>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -1370,9 +1369,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......@@ -1413,7 +1413,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
// Gridwise GEMM size
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment