Commit cfc80c01 authored by ltqin

Merge branch 'develop' into ck_conv_bww_fp16

parents 69ea9ad9 6d4450ef
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r3.hpp" #include "blockwise_tensor_slice_transfer_v6r3.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck { namespace ck {
...@@ -24,7 +25,7 @@ template <typename GridwiseGemm, ...@@ -24,7 +25,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainK0BlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -50,7 +51,7 @@ __global__ void ...@@ -50,7 +51,7 @@ __global__ void
{ {
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -109,7 +110,8 @@ template < ...@@ -109,7 +110,8 @@ template <
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false; return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only supports an even number of K0 loop iterations
// TODO: add support for an odd number of K0 loop iterations
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// check M01, N01 // check M01, N01
constexpr auto M1 = Number<MPerBlock>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{}; constexpr auto N1 = Number<NPerBlock>{};
...@@ -267,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -267,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return grid_size; return grid_size;
} }
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{ {
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
return has_main_k0_block_loop; return has_main_k0_block_loop;
} }
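To make the new NumPrefetch parameter concrete: the supported-configuration check above rejects 2-stage prefetch unless the number of K0 sub-loops is even, and the main-loop predicate now divides by NumPrefetch * K0PerBlock. A minimal host-side sketch of that arithmetic (illustrative names, not part of the commit; only the arithmetic mirrors the kernel code):

// Mirrors the validity check and the main-loop predicate above; not CK code.
constexpr bool is_num_prefetch_supported(int K0, int K0PerBlock, int NumPrefetch)
{
    if(K0 % K0PerBlock != 0)
        return false;                      // K0 must tile evenly into K0PerBlock
    if(NumPrefetch == 1)
        return true;                       // 1-stage prefetch always supported
    if(NumPrefetch == 2)
        return (K0 / K0PerBlock) % 2 == 0; // 2-stage needs an even number of K0 sub-loops
    return false;                          // other prefetch depths not supported
}

constexpr bool has_main_k0_block_loop(int K0, int K0PerBlock, int NumPrefetch)
{
    return (K0 / (NumPrefetch * K0PerBlock)) > 1;
}

static_assert(is_num_prefetch_supported(32, 4, 2), "8 sub-loops (even): OK for 2-stage prefetch");
static_assert(!is_num_prefetch_supported(12, 4, 2), "3 sub-loops (odd): rejected for 2-stage prefetch");
static_assert(has_main_k0_block_loop(32, 4, 2), "8 sub-loops / 2 stages > 1, so the main loop runs");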
...@@ -303,7 +325,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -303,7 +325,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
...@@ -324,17 +346,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -324,17 +346,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_m0_n0_block_cluster_adaptor; return cblockid_to_m0_n0_block_cluster_adaptor;
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<decltype(
...@@ -351,9 +373,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -351,9 +373,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C1GridDesc_M_N{}))>; C1GridDesc_M_N{}))>;
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop> template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
@@ -509,51 +532,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
 
-        // preload data into LDS
-        {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-        }
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        // main body
-        if constexpr(HasMainKBlockLoop)
-        {
-            index_t k0_block_data_begin = 0;
-
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-                k0_block_data_begin += K0PerBlock;
-            } while(k0_block_data_begin < (K0 - K0PerBlock));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
+        // gridwise GEMM pipeline
+        const auto gridwise_gemm_pipeline =
+            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                    remove_cvref_t<decltype(a_blockwise_copy)>,
+                                    remove_cvref_t<decltype(a_grid_buf)>,
+                                    remove_cvref_t<decltype(a_block_buf)>,
+                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                    remove_cvref_t<decltype(b_blockwise_copy)>,
+                                    remove_cvref_t<decltype(b_grid_buf)>,
+                                    remove_cvref_t<decltype(b_block_buf)>,
+                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(blockwise_gemm)>,
+                                    remove_cvref_t<decltype(c_thread_buf)>,
+                                    NumPrefetch,
+                                    HasMainK0BlockLoop>{};
+
+        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                   a_block_desc_k0_m_k1,
+                                   a_blockwise_copy,
+                                   a_grid_buf,
+                                   a_block_buf,
+                                   a_block_slice_copy_step,
+                                   b_grid_desc_k0_n_k1,
+                                   b_block_desc_k0_n_k1,
+                                   b_blockwise_copy,
+                                   b_grid_buf,
+                                   b_block_buf,
+                                   b_block_slice_copy_step,
+                                   blockwise_gemm,
+                                   c_thread_buf,
+                                   K0BlockMainLoop);
 
         // shuffle C and write out
         {
......
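The hand-written preload / main-loop / tail sequence removed above is now delegated to a GridwiseGemmPipeline_v1 object parameterized on NumPrefetch and HasMainK0BlockLoop. For intuition, a self-contained sketch of the scheduling idea behind a 2-stage (double-buffered) pipeline; Buffer, load_tile, compute_tile and run_pipeline are illustrative stand-ins, not the pipeline's real interface:

// Conceptual 2-stage prefetch: while tile k is being computed from one LDS
// buffer, tile k+1 is loaded into the other buffer.
#include <array>
#include <cstdio>

struct Buffer { int tile_id = -1; };

void load_tile(Buffer& b, int k)   { b.tile_id = k; }                               // stand-in for the global->LDS copy
void compute_tile(const Buffer& b) { std::printf("gemm on tile %d\n", b.tile_id); } // stand-in for the blockwise GEMM

void run_pipeline(int num_k0_loops) // assumed even, as required for 2-stage prefetch
{
    std::array<Buffer, 2> lds{};               // two LDS buffers -> two pipeline stages
    load_tile(lds[0], 0);                      // prologue: prefetch the first tile
    for(int k = 0; k + 1 < num_k0_loops; ++k)
    {
        load_tile(lds[(k + 1) % 2], k + 1);    // prefetch the next tile into the other buffer
        compute_tile(lds[k % 2]);              // overlap compute with that prefetch
    }
    compute_tile(lds[(num_k0_loops - 1) % 2]); // epilogue: compute the last tile
}

int main() { run_pipeline(4); }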
...@@ -64,9 +64,10 @@ template <typename SliceLengths, ...@@ -64,9 +64,10 @@ template <typename SliceLengths,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to // RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation // save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to // RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation // save addr computation
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1 struct ThreadwiseTensorSliceTransfer_v3r1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -78,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -78,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
...@@ -102,9 +105,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -102,9 +105,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcBuffer, typename SrcStepHacks> template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void __device__ void RunRead(const SrcDesc& src_desc,
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -114,9 +118,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -114,9 +118,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim // scalar per access on each dim
// TODO: don't use lambda_scalar_per_access // TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence( constexpr auto src_scalar_per_access = generate_sequence(
...@@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(src_desc, forward_step_idx);
src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(src_desc, backward_step_idx);
src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -215,8 +214,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -215,8 +214,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}); });
// copy data from src_vector_container into src_thread_scratch_ // copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_.template SetAsType<src_vector_t>( src_thread_scratch_tuple_(thread_scratch_id)
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]); .template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -263,12 +263,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -263,12 +263,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
} }
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() template <index_t ThreadScratchId>
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{ {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here // convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]); dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
}); });
#else #else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
...@@ -318,7 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -318,7 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
const auto src_vector_refs = generate_tie( const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& { [&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim // i increment corresponds to movement in DstVectorDim
return src_thread_scratch_.GetVectorTypeReference( return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector); data_idx_seq + i * dst_scalar_step_in_vector);
}, },
Number<num_src_vector>{}); Number<num_src_vector>{});
...@@ -342,19 +345,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -342,19 +345,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{ {
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here // convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]); dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
}); });
} }
#endif #endif
} }
template <typename DstBuffer, typename DstStepHacks> template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void __device__ void RunWrite(const DstDesc& dst_desc,
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
// if there is transpose, it's done here // if there is transpose, it's done here
// TODO move this elsewhere // TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch(); TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...@@ -364,9 +369,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -364,9 +369,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong"); "wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// src scalar per access on each dim // src scalar per access on each dim
// TODO: don't use this // TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence( constexpr auto dst_scalar_per_access = generate_sequence(
...@@ -388,8 +390,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -388,8 +390,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(dst_desc, forward_step_idx);
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -402,8 +403,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -402,8 +403,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(dst_desc, backward_step_idx);
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -515,39 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -515,39 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
} }
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
constexpr index_t ntransform_src = remove_cvref_t<SrcDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_step_hacks);
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
// TODO: why need remove_cvref_t ?
constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_step_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
{ {
constexpr auto I0 = Number<0>{};
// scalar per access on each dim // scalar per access on each dim
// TODO: don't use lambda_scalar_per_access // TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence( constexpr auto src_scalar_per_access = generate_sequence(
...@@ -606,8 +575,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -606,8 +575,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
{ {
constexpr auto I0 = Number<0>{};
// scalar per access on each dim // scalar per access on each dim
// TODO: don't use lambda_scalar_per_access // TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence( constexpr auto dst_scalar_per_access = generate_sequence(
...@@ -679,25 +646,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -679,25 +646,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
move_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx) const Index& dst_slice_origin_step_idx)
...@@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr, using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
SrcData, SrcData,
SrcScalarPerVector, SrcScalarPerVector,
decltype(src_thread_scratch_desc_), decltype(src_thread_scratch_desc_),
true> true>;
src_thread_scratch_;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr, DstData,
DstData, DstScalarPerVector,
DstScalarPerVector, decltype(dst_thread_scratch_desc_),
decltype(dst_thread_scratch_desc_), true>;
true>
dst_thread_scratch_; StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
......
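RunRead() and RunWrite() above now take a compile-time thread_scratch_id, and the transfer keeps NumThreadScratch private scratch buffers instead of one, which is what lets the gridwise pipeline keep two reads in flight. A generic, self-contained illustration of that pattern (assumed names, not the CK API):

// One transfer object owns several scratch slots; the caller picks a slot with a
// compile-time index, so reads for iteration k+1 can land while slot k is drained.
#include <array>
#include <cstdio>

template <int NumThreadScratch = 1>
struct TransferWithScratch
{
    template <int Id>
    void RunRead(int value) // stage data into scratch slot Id
    {
        static_assert(Id >= 0 && Id < NumThreadScratch, "scratch id out of range");
        scratch_[Id] = value;
    }

    template <int Id>
    void RunWrite() const // drain scratch slot Id to the destination
    {
        static_assert(Id >= 0 && Id < NumThreadScratch, "scratch id out of range");
        std::printf("writing value %d from scratch slot %d\n", scratch_[Id], Id);
    }

    std::array<int, NumThreadScratch> scratch_{};
};

int main()
{
    TransferWithScratch<2> transfer; // two slots, matching 2-stage prefetch
    transfer.RunRead<0>(10);         // prefetch iteration k into slot 0
    transfer.RunRead<1>(11);         // prefetch iteration k+1 into slot 1
    transfer.RunWrite<0>();          // consume slot 0
    transfer.RunWrite<1>();          // then consume slot 1
}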
...@@ -920,10 +920,10 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ ...@@ -920,10 +920,10 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type __device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_element_offset, index_t src_thread_element_offset,
bool src_thread_element_valid, bool src_thread_element_valid,
index_t src_element_space_size) index_t src_element_space_size)
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size); make_wave_buffer_resource(p_src_wave, src_element_space_size);
......
...@@ -49,7 +49,7 @@ template <typename X, typename... Xs> ...@@ -49,7 +49,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cvref_t<X>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}}; return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
} }
// make empty array // make empty array
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "transpose_vectors.hpp" #include "transpose_vectors.hpp"
#include "inner_product.hpp" #include "inner_product.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "debug.hpp"
// TODO: remove this // TODO: remove this
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
......
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
namespace ck {
namespace debug {
namespace detail {
template <typename T, typename Enable = void>
struct PrintAsType;
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{
using type = float;
};
template <>
struct PrintAsType<ck::half_t, void>
{
using type = float;
};
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{
using type = int;
};
} // namespace detail
// Prints, at runtime, the contents of shared memory in a 128-bytes-per-row format, given a
// pointer to the shared memory and the number of elements. The stride between elements and
// the number of bytes printed per row can optionally be set via template parameters.
//
// Usage example:
//
// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
//
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements)
{
using PrintType = typename detail::PrintAsType<T>::type;
constexpr index_t row_elements = row_bytes / sizeof(T);
static_assert((element_stride >= 1 && element_stride <= row_elements),
"element_stride should between [1, row_elements]");
index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
index_t tid =
(threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;
__syncthreads();
if(tid == 0)
{
printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n",
wgid,
row_bytes,
element_stride);
for(index_t i = 0; i < num_elements; i += row_elements)
{
printf("elem %5d: ", i);
for(index_t j = 0; j < row_elements; j += element_stride)
{
printf("%.0f ", static_cast<PrintType>(p_shared[i + j]));
}
printf("\n");
}
printf("\n");
}
__syncthreads();
}
} // namespace debug
} // namespace ck
#endif // UTILITY_DEBUG_HPP
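A hedged, standalone example of exercising the new helper (assumes a HIP toolchain and the CK utility headers on the include path; the kernel and buffer names are illustrative, not from this commit):

// Fill a small LDS buffer and dump it with debug::print_shared.
#include <hip/hip_runtime.h>
#include "debug.hpp"

__global__ void dump_lds_example()
{
    __shared__ float lds[64];
    lds[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();

    // 64 floats at 128 bytes (32 floats) per row -> thread 0 prints two rows
    ck::debug::print_shared(lds, 64);
}

int main()
{
    hipLaunchKernelGGL(dump_lds_example, dim3(1), dim3(64), 0, 0);
    hipDeviceSynchronize();
    return 0;
}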
...@@ -56,7 +56,7 @@ struct DynamicBuffer ...@@ -56,7 +56,7 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_LOAD
bool constexpr use_amd_buffer_addressing = true; bool constexpr use_amd_buffer_addressing = true;
#else #else
bool constexpr use_amd_buffer_addressing = false; bool constexpr use_amd_buffer_addressing = false;
...@@ -68,8 +68,7 @@ struct DynamicBuffer ...@@ -68,8 +68,7 @@ struct DynamicBuffer
if constexpr(InvalidElementUseNumericalZeroValue) if constexpr(InvalidElementUseNumericalZeroValue)
{ {
return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, t_per_x>(
t_per_x>(
p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_);
} }
else else
...@@ -125,7 +124,7 @@ struct DynamicBuffer ...@@ -125,7 +124,7 @@ struct DynamicBuffer
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_STORE
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x>( amd_buffer_store<remove_cvref_t<T>, t_per_x>(
...@@ -291,7 +290,7 @@ struct DynamicBuffer ...@@ -291,7 +290,7 @@ struct DynamicBuffer
static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ATOMIC_ADD
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
......
...@@ -13,5 +13,38 @@ struct integral_constant ...@@ -13,5 +13,38 @@ struct integral_constant
__host__ __device__ constexpr value_type operator()() const noexcept { return value; } __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
}; };
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator+(integral_constant<TX, X>, integral_constant<TY, Y>)
{
return integral_constant<decltype(X + Y), X + Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator-(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y <= X, "wrong!");
return integral_constant<decltype(X - Y), X - Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator*(integral_constant<TX, X>, integral_constant<TY, Y>)
{
return integral_constant<decltype(X * Y), X * Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator/(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X / Y), X / Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X % Y), X % Y>{};
}
} // namespace ck } // namespace ck
#endif #endif
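Since Number<N> is just integral_constant<index_t, N>, these operators are what now carries compile-time index arithmetic (the Number<>-specific overloads are deleted further down). A minimal usage sketch, assuming the ck headers are on the include path; the header name below is an assumption:

// Compile-time arithmetic on integral_constant; the result type follows decltype(X op Y).
#include <cstdint>
#include "integral_constant.hpp" // assumed header name for the file shown above

int main()
{
    constexpr auto six = ck::integral_constant<int32_t, 6>{};
    constexpr auto two = ck::integral_constant<int64_t, 2>{};

    constexpr auto sum   = six + two; // integral_constant<int64_t, 8>
    constexpr auto ratio = six / two; // integral_constant<int64_t, 3>

    static_assert(sum.value == 8, "compile-time add");
    static_assert(ratio.value == 3, "compile-time divide");
    return 0;
}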
...@@ -17,6 +17,12 @@ struct is_known_at_compile_time<index_t> ...@@ -17,6 +17,12 @@ struct is_known_at_compile_time<index_t>
static constexpr bool value = false; static constexpr bool value = false;
}; };
template <>
struct is_known_at_compile_time<long_index_t>
{
static constexpr bool value = false;
};
template <typename T, T X> template <typename T, T X>
struct is_known_at_compile_time<integral_constant<T, X>> struct is_known_at_compile_time<integral_constant<T, X>>
{ {
......
...@@ -111,24 +111,39 @@ struct MagicDivision ...@@ -111,24 +111,39 @@ struct MagicDivision
} }
// magic division for uint32_t // magic division for uint32_t
__host__ __device__ static constexpr uint32_t __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = __umulhi(dividend, multiplier); uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift; return (tmp + dividend) >> shift;
} }
__host__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t // magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
// non-negative for the result to be correct // non-negative for the result to be correct
// TODO: figure out how to do magic number division for int32_t as dividend // TODO: figure out how to do magic number division for int32_t as dividend
__host__ __device__ static constexpr int32_t __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32); uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier); uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift; return (tmp + dividend_u32) >> shift;
} }
__host__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
return (tmp + dividend_u32) >> shift;
}
}; };
} // namespace ck } // namespace ck
......
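Both DoMagicDivision overloads compute floor(dividend / divisor) from a precomputed (multiplier, shift) pair: the __device__ path uses __umulhi, and the new __host__ path obtains the same high 32 bits through a 64-bit multiply. A self-contained host sketch of the trick; the derivation in make_magic below is one standard round-up scheme and is not claimed to be the multiplier/shift computation CK itself uses:

#include <cassert>
#include <cstdint>

struct MagicPair { uint32_t multiplier; uint32_t shift; };

MagicPair make_magic(uint32_t divisor) // divisor in [1, 2^31]
{
    uint32_t shift = 0;
    while((uint64_t{1} << shift) < divisor) ++shift;
    // multiplier + 2^32 == ceil(2^(32+shift) / divisor)
    uint64_t m_prime = ((uint64_t{1} << (32 + shift)) + divisor - 1) / divisor;
    return {static_cast<uint32_t>(m_prime - (uint64_t{1} << 32)), shift};
}

// Mirrors the __host__ DoMagicDivision added above; the 32-bit add (tmp + dividend)
// must not overflow, which holds for the tensor-index magnitudes this is used on.
uint32_t magic_div(uint32_t dividend, MagicPair p)
{
    uint32_t tmp = static_cast<uint64_t>(dividend) * p.multiplier >> 32;
    return (tmp + dividend) >> p.shift;
}

int main()
{
    for(uint32_t d : {1u, 3u, 7u, 80u, 1000u})
    {
        MagicPair p = make_magic(d);
        for(uint32_t n : {0u, 1u, 5u, 79u, 80u, 81u, 123456789u})
            assert(magic_div(n, p) == n / d);
    }
    return 0;
}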
...@@ -8,37 +8,5 @@ namespace ck { ...@@ -8,37 +8,5 @@ namespace ck {
template <index_t N> template <index_t N>
using Number = integral_constant<index_t, N>; using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck } // namespace ck
#endif #endif
#include "math.hpp"
#include "sequence.hpp"
#include "tensor_adaptor.hpp"
#include "statically_indexed_array_multi_index.hpp"
#include "tuple_helper.hpp"
namespace ck {
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess> // # of scalars per access in each dimension
struct SpaceFillingCurve
{
static constexpr index_t nDim = TensorLengths::Size();
using Index = MultiIndex<nDim>;
static constexpr index_t ScalarPerVector =
reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{});
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
static constexpr auto dim_access_order = DimAccessOrder{};
static constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(ordered_access_lengths)),
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr index_t GetNumOfAccess()
{
return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
ScalarPerVector;
}
template <index_t AccessIdx1d>
static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
{
constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
constexpr auto idx_next = GetIndex(Number<AccessIdx1d + 1>{});
return idx_next - idx_curr;
}
template <index_t AccessIdx1d>
static __device__ __host__ constexpr auto GetBackwardStep(Number<AccessIdx1d>)
{
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
constexpr auto idx_prev = GetIndex(Number<AccessIdx1d - 1>{});
return idx_prev - idx_curr;
}
template <index_t AccessIdx1d>
static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
{
#if 0
/*
* \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
*/
constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
#else
constexpr auto access_strides = container_reverse_exclusive_scan(
ordered_access_lengths, math::multiplies{}, Number<1>{});
constexpr auto idx_1d = Number<AccessIdx1d>{};
// Given the access strides \p access_strides and the 1D index of the space-filling curve,
// compute the idim-th element of the multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
auto res = idx_1d.value;
auto id = 0;
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
id = res / access_strides[kdim].value;
res -= id * access_strides[kdim].value;
});
return id;
};
constexpr auto id = compute_index_impl(idim);
return Number<id>{};
};
constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
#endif
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto idim) {
index_t tmp = ordered_access_idx[I0];
static_for<1, idim, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(idim) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate multi-dim tensor index
auto idx_md = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 -
ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
ScalarsPerAccess{};
}();
return idx_md;
}
};
} // namespace ck
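SpaceFillingCurve converts a 1D access counter into a multidimensional index that sweeps back and forth, so consecutive accesses touch adjacent elements and the forward/backward steps stay small. A compile-time usage sketch with illustrative template arguments (a 4x8 tile, dimension order {0,1}, 4 scalars per access along the last dimension); it assumes the CK headers defining SpaceFillingCurve, Sequence and Number are included:

using SFC = ck::SpaceFillingCurve<ck::Sequence<4, 8>,
                                  ck::Sequence<0, 1>,
                                  ck::Sequence<1, 4>>;

static_assert(SFC::GetNumOfAccess() == 8, "4*8 elements / 4 scalars per access");

// Multidimensional index of the first two accesses, and the step between them,
// all evaluated at compile time.
constexpr auto idx0 = SFC::GetIndex(ck::Number<0>{});
constexpr auto idx1 = SFC::GetIndex(ck::Number<1>{});
constexpr auto step = SFC::GetForwardStep(ck::Number<0>{}); // equals idx1 - idx0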
...@@ -13,6 +13,8 @@ __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size() ...@@ -13,6 +13,8 @@ __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size()
__device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_grid_size() { return gridDim.x; }
} // namespace ck } // namespace ck
#endif #endif
...@@ -83,7 +83,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -83,7 +83,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
void* p_a_k_m0_m1_grid_desc, void* p_a_k_m0_m1_grid_desc,
void* p_b_k_n0_n1_grid_desc, void* p_b_k_n0_n1_grid_desc,
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -194,7 +194,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -194,7 +194,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
auto c_m0_m10_m11_n0_n10_n11_grid_desc = auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -203,8 +203,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -203,8 +203,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; *static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>( *static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}; };
}; };
...@@ -219,7 +219,7 @@ extern "C" __global__ void ...@@ -219,7 +219,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc, const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -332,14 +332,13 @@ extern "C" __global__ void ...@@ -332,14 +332,13 @@ extern "C" __global__ void
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k_m0_m1_grid_desc = const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc); *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
...@@ -348,9 +347,9 @@ extern "C" __global__ void ...@@ -348,9 +347,9 @@ extern "C" __global__ void
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -364,7 +363,7 @@ extern "C" __global__ void ...@@ -364,7 +363,7 @@ extern "C" __global__ void
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor, cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
}; };
...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
void* p_a_k0_m_k1_grid_desc, void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc, void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc, void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
b_k0_n_k1_grid_desc; b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) = *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc; c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
} }
}; };
...@@ -215,7 +215,7 @@ extern "C" __global__ void ...@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -325,12 +325,11 @@ extern "C" __global__ void ...@@ -325,12 +325,11 @@ extern "C" __global__ void
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc = const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc); *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
...@@ -338,9 +337,9 @@ extern "C" __global__ void ...@@ -338,9 +337,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc); *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc = const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc); *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -354,5 +353,5 @@ extern "C" __global__ void ...@@ -354,5 +353,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor); cblockid_to_m0_n0_block_cluster_adaptor);
}; };
...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
void* p_a_k0_m_k1_grid_desc, void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc, void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc, void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
b_k0_n_k1_grid_desc; b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) = *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc; c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
} }
}; };
...@@ -215,7 +215,7 @@ extern "C" __global__ void ...@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -324,12 +324,11 @@ extern "C" __global__ void ...@@ -324,12 +324,11 @@ extern "C" __global__ void
false>; false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc = const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc); *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
...@@ -337,9 +336,9 @@ extern "C" __global__ void ...@@ -337,9 +336,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc); *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc = const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc); *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -353,5 +352,5 @@ extern "C" __global__ void ...@@ -353,5 +352,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor); cblockid_to_m0_n0_block_cluster_adaptor);
}; };
...@@ -26,11 +26,16 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -26,11 +26,16 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
) ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
)
# device_gemm_bias_2d_instance # device_gemm_bias_2d_instance
set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE
...@@ -76,6 +81,11 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ...@@ -76,6 +81,11 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
) )
# device_conv1d_fwd_instance
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
)
# device_conv2d_fwd_bias_relu_instance # device_conv2d_fwd_bias_relu_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
...@@ -96,16 +106,18 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S ...@@ -96,16 +106,18 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
...@@ -116,6 +128,7 @@ target_compile_features(device_gemm_bias_2d_instance PUBLIC) ...@@ -116,6 +128,7 @@ target_compile_features(device_gemm_bias_2d_instance PUBLIC)
target_compile_features(device_gemm_bias_relu_instance PUBLIC) target_compile_features(device_gemm_bias_relu_instance PUBLIC)
target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) target_compile_features(device_gemm_bias_relu_add_instance PUBLIC)
target_compile_features(device_batched_gemm_instance PUBLIC) target_compile_features(device_batched_gemm_instance PUBLIC)
target_compile_features(device_conv1d_fwd_instance PUBLIC)
target_compile_features(device_conv2d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_instance PUBLIC)
target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
...@@ -126,6 +139,7 @@ set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDE ...@@ -126,6 +139,7 @@ set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDE
set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
...@@ -136,7 +150,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) ...@@ -136,7 +150,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
#ifndef CONV_UTILS_HPP
#define CONV_UTILS_HPP
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <type_traits>
#include <vector>
#include "config.hpp"
#include "host_tensor.hpp"
#include "tensor_layout.hpp"
namespace ck {
namespace conv_util {
/**
* @brief Calculate the number of FLOPs for a convolution.
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] filter_spatial_lengths Filter spatial dimension lengths.
* @param[in] output_spatial_lengths Convolution output spatial dimension lengths.
*
* @return The number of FLOPs.
*/
std::size_t GetFlops(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths)
{
// 2 * N * K * <output spatial lengths product> * C * <filter spatial lengths product>
return static_cast<std::size_t>(2) * N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()) *
C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
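// A minimal usage sketch (hypothetical helper, not part of the original header):
// for a 2D convolution with N=4, C=16, K=32, a 3x3 filter and a 28x28 output,
// GetFlops evaluates 2 * 4 * 32 * 28 * 28 * 16 * 3 * 3 = 28,901,376 FLOPs.
inline std::size_t GetFlopsExample()
{
    const std::vector<ck::index_t> filter_spatial_lengths{3, 3};   // Y, X
    const std::vector<ck::index_t> output_spatial_lengths{28, 28}; // Ho, Wo
    return GetFlops(4, 16, 32, filter_spatial_lengths, output_spatial_lengths);
}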
/**
* @brief Calculate the number of bytes read and written by the convolution algorithm.
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] input_spatial_lengths Input spatial dimension lengths.
* @param[in] filter_spatial_lengths Filter spatial dimension lengths.
* @param[in] output_spatial_lengths Output spatial dimension lengths.
*
* @tparam InDataType Input tensor data type.
* @tparam WeiDataType Weights tensor data type.
* @tparam OutDataType Output tensor data type.
*
* @return The total number of bytes read and written.
*/
template <typename InDataType = float,
typename WeiDataType = InDataType,
typename OutDataType = InDataType>
std::size_t GetBtype(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths)
{
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
return sizeof(InDataType) * (N * C *
std::accumulate(std::begin(input_spatial_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) * (K * C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
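// A minimal usage sketch (hypothetical helper, not part of the original header):
// data-movement estimate for a float 2D convolution with N=4, C=16, K=32,
// a 30x30 input, 3x3 filter and 28x28 output:
// 4 * (4*16*30*30 + 32*16*3*3 + 4*32*28*28) = 650,240 bytes.
inline std::size_t GetBtypeExample()
{
    return GetBtype<float>(4, 16, 32, {30, 30}, {3, 3}, {28, 28});
}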
struct ConvParams
{
ConvParams()
: num_dim_spatial(2),
N(128),
K(256),
C(192),
filter_spatial_lengths(2, 3),
input_spatial_lengths(2, 71),
conv_filter_strides(2, 2),
conv_filter_dilations(2, 1),
input_left_pads(2, 1),
input_right_pads(2, 1)
{
}
ck::index_t num_dim_spatial;
ck::index_t N;
ck::index_t K;
ck::index_t C;
std::vector<ck::index_t> filter_spatial_lengths;
std::vector<ck::index_t> input_spatial_lengths;
std::vector<ck::index_t> conv_filter_strides;
std::vector<ck::index_t> conv_filter_dilations;
std::vector<ck::index_t> input_left_pads;
std::vector<ck::index_t> input_right_pads;
std::vector<ck::index_t> GetOutputSpatialLengths() const
{
std::vector<ck::index_t> out_spatial_len(num_dim_spatial, 0);
for(ck::index_t i = 0; i < num_dim_spatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t idx_eff =
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
out_spatial_len[i] =
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
conv_filter_strides[i] +
1;
}
return out_spatial_len;
}
};
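// Worked example (hypothetical helper, for illustration only): with the default
// ConvParams above (3x3 filter, 71x71 input, stride 2, dilation 1, pads 1/1),
// XEff = (3 - 1) * 1 + 1 = 3 and Wo = (71 + 1 + 1 - 3) / 2 + 1 = 36, i.e. a 36x36 output.
inline std::vector<ck::index_t> GetDefaultOutputSpatialLengthsExample()
{
    ConvParams params;                       // 2D defaults: N=128, K=256, C=192
    return params.GetOutputSpatialLengths(); // {36, 36}
}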
/**
* @brief Gets the host tensor descriptor.
*
* @param[in] dims The tensor dimension lengths, always given in NCW (1D) or NCHW (2D) order.
* @param[in] layout The tensor data layout.
*
* @tparam TensorLayout Layout type.
*
* @return The host tensor descriptor object.
*/
template <typename TensorLayout>
HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dims,
const TensorLayout& layout)
{
std::size_t C = dims[1];
// 1D
if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKW>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>({C * dims[2], dims[2], 1}));
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NWK>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>({C * dims[2], 1, C}));
}
// 2D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKHW>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWK>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
}
std::stringstream err_msg;
err_msg << "Unsupported data layout provided: " << layout << "!";
throw std::runtime_error(err_msg.str());
}
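// Usage sketch (hypothetical helper, for illustration only): a 4D NHWC tensor with
// lengths given in NCHW order {N=2, C=8, H=16, W=16} gets strides
// {C*H*W, 1, W*C, C} = {2048, 1, 128, 8}.
inline HostTensorDescriptor GetNHWCDescriptorExample()
{
    return GetHostTensorDescriptor(std::vector<std::size_t>{2, 8, 16, 16},
                                   tensor_layout::convolution::NHWC{});
}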
} // namespace conv_util
} // namespace ck
#endif
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
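// Usage sketch (hypothetical helper, for illustration only): a 3x3 filter over a
// 71x71 input with stride 2, dilation 1 and symmetric padding of 1 produces a
// 36x36 output, matching ConvParams::GetOutputSpatialLengths() in conv_utils.hpp.
inline std::vector<index_t> ComputeOutputSpatialLengthsExample()
{
    return ConvolutionUtility::ComputeOutputSpatialLengths(
        {71, 71}, {3, 3}, {2, 2}, {1, 1}, {1, 1}, {1, 1});
}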
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -248,7 +248,7 @@ struct DeviceBatchedGemmXdl ...@@ -248,7 +248,7 @@ struct DeviceBatchedGemmXdl
c_grid_desc_g_m_n_); c_grid_desc_g_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseBatchedGemm::MakeBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01); GridwiseBatchedGemm::MakeDefaultBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01);
} }
} }
...@@ -261,7 +261,7 @@ struct DeviceBatchedGemmXdl ...@@ -261,7 +261,7 @@ struct DeviceBatchedGemmXdl
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseBatchedGemm::CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseBatchedGemm::CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseBatchedGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseBatchedGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -327,7 +327,7 @@ struct DeviceBatchedGemmXdl ...@@ -327,7 +327,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>, remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
...@@ -359,7 +359,7 @@ struct DeviceBatchedGemmXdl ...@@ -359,7 +359,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>, remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
......