"...test_cli/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "8e981b3ac2277ac02a0590727ff2c76608c93c78"
Unverified commit acbd7bd7, authored by Chao Liu, committed by GitHub

Fusion Conv+Bias+ReLU(+Add) (#62)

* fix relu

* clean up

* clean up

* adding 1x1 conv

* adding 1x1 conv

* added 1x1 conv

* refactor

* refactor

* refactor

* added profiler for conv+bias+relu+add

* clean up

* adding conv+bias+relu

* adding conv+bias+relu

* added conv+bias+relu

* Update README.md

* update cpu verification

* adding c shuffle

* update static_tensor for dealing with invalid element

* adding c shuffle

* debugging

* fix bug

* convert to fp16 before shuffle

* shuffle more than one M/NRepeat

* clean up

* remove coordinate step hack from GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1

* clean up

* remove coordinate step hack from all gridwise gemm xdl

* clean up coordinate step hack

* clean up coordinate step hack

* ThreadwiseTensorSliceTransfer_v3r2 support pointwise op on both src and dst

* adding output shuffle in conv+bias+relu+add

* update

* added conv+bias+relu+add with c shuffle

* added conv+bias+relu+add with c shuffle

* fix forward_sweep bugs in threadwise copy

* clean up

* refactor

* clean up

* clean up

* added conv_c_shuffle+bias_relu

* clean up

* added conv+bias+relu+atomic_add

* clean up

* clean up

* clean up

* clean up

* clean up

* clean up

* misc fixes; add 1x1 specialization

* clean up

* delete unused device op

* clean up

* add support for odd C value
parent a4f24233
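For orientation before the file-by-file diff: per output element the fused epilogue computes y = ReLU(conv + bias) (+ residual). A minimal host-side sketch of that arithmetic follows; the functor mirrors the AddReluAdd operator added later in this commit, while the function name, the pre-broadcast bias buffer, and the plain-pointer interface are illustrative assumptions rather than part of the change.

#include <algorithm>
#include <cstddef>

// Reference functor for y = relu(x0 + x1) + x2, mirroring the AddReluAdd op added below.
struct AddReluAddRef
{
    float operator()(float x0, float x1, float x2) const
    {
        float a = x0 + x1;          // conv output + bias
        float b = std::max(a, 0.f); // ReLU
        return b + x2;              // + residual
    }
};

// conv_out: raw convolution/GEMM result, bias: already broadcast per element,
// residual: tensor added after the ReLU; all buffers hold n elements.
void fused_epilogue_reference(
    const float* conv_out, const float* bias, const float* residual, float* out, std::size_t n)
{
    AddReluAddRef op;
    for(std::size_t i = 0; i < n; ++i)
        out[i] = op(conv_out[i], bias[i], residual[i]);
}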
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP
-#include "ignore.hpp"
namespace ck {
// StaticTensor for Scalar

@@ -17,10 +15,10 @@ struct StaticTensor
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
-__host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
+__host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}
__host__ __device__ constexpr StaticTensor(T invalid_element_value)
-: invalid_element_value_{invalid_element_value}
+: invalid_element_scalar_value_{invalid_element_value}
{
}

@@ -44,11 +42,11 @@ struct StaticTensor
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
-return T{0};
+return zero_scalar_value_;
}
else
{
-return invalid_element_value_;
+return invalid_element_scalar_value_;
}
}
}

@@ -71,12 +69,14 @@ struct StaticTensor
}
else
{
-return ignore;
+return ignored_element_scalar_;
}
}
StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
-T invalid_element_value_ = T{0};
+static constexpr T zero_scalar_value_ = T{0};
+const T invalid_element_scalar_value_;
+T ignored_element_scalar_;
};
// StaticTensor for vector

@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
using V = vector_type<S, ScalarPerVector>;
-__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
+__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
+: invalid_element_scalar_value_{0}
+{
+}
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
-: invalid_element_value_{invalid_element_value}
+: invalid_element_scalar_value_{invalid_element_value}
{
}

@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
-return S{0};
+return zero_scalar_value_;
}
else
{
-return invalid_element_value_;
+return invalid_element_scalar_value_;
}
}
}

@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
}
else
{
-return ignore;
+return ignored_element_scalar_;
}
}

@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
else
{
// TODO: is this right way to initialize a vector?
-return X{invalid_element_value_};
+return X{invalid_element_scalar_value_};
}
}
}

@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
}
StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
-S invalid_element_value_ = S{0};
+static constexpr S zero_scalar_value_ = S{0};
+const S invalid_element_scalar_value_ = S{0};
+S ignored_element_scalar_;
};
template <AddressSpaceEnum_t AddressSpace,
...
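The static_tensor change above stores the invalid-element value (and a writable scratch scalar) inside the tensor instead of materializing T{0} or returning the freestanding ignore object. A toy model of why that helps, an illustrative assumption rather than CK code: element access can then always hand back a real lvalue, even for invalid (out-of-range or predicated-off) coordinates.

#include <array>
#include <cstddef>

template <typename T, std::size_t N>
struct ToyStaticTensor
{
    // a read of an invalid element yields the stored value instead of a temporary
    const T& At(std::size_t i, bool is_valid) const
    {
        return is_valid ? data_[i] : invalid_element_scalar_value_;
    }

    // a write to an invalid element lands in a scratch slot and is effectively dropped
    T& At(std::size_t i, bool is_valid)
    {
        return is_valid ? data_[i] : ignored_element_scalar_;
    }

    std::array<T, N> data_{};
    T invalid_element_scalar_value_{};
    T ignored_element_scalar_{};
};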
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v3r2.hpp"
+#include "threadwise_tensor_slice_transfer_v3r1.hpp"
namespace ck {

@@ -15,9 +15,9 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename SrcElementwiseOperation,
+typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
-typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,

@@ -34,35 +34,38 @@ template <index_t BlockSize,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
+static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
-__device__ constexpr BlockwiseTensorSliceTransfer_v4(
+__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
+const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
-const SrcElementwiseOperation& src_element_op)
+const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
+src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
-src_element_op)
+dst_element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
-nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
-is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
+is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),

@@ -74,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
-const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
+const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);

@@ -114,6 +117,16 @@ struct BlockwiseTensorSliceTransfer_v4
}
}
+template <typename SrcBuffer, typename DstBuffer>
+__device__ void Run(const SrcDesc& src_desc,
+const SrcBuffer& src_buf,
+const DstDesc& dst_desc,
+DstBuffer& dst_buf)
+{
+RunRead(src_desc, src_buf);
+RunWrite(dst_desc, dst_buf);
+}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or

@@ -152,8 +165,9 @@ struct BlockwiseTensorSliceTransfer_v4
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
-ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
SrcElementwiseOperation,
+DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
...
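The v4r1 blockwise transfer keeps the two-phase copy (RunRead into a thread-private buffer, RunWrite into the destination), adds a separate destination element-wise operation, and introduces a convenience Run() that simply chains the two phases. A simplified host-side analogue, under the assumption of plain pointers in place of tensor descriptors and coordinates:

#include <cstddef>
#include <vector>

template <typename SrcOp, typename DstOp>
struct ToyBlockwiseTransfer
{
    SrcOp src_op;                  // applied while reading the source
    DstOp dst_op;                  // applied while writing the destination
    std::vector<float> thread_buf; // stand-in for the thread-private static buffer

    void RunRead(const float* src, std::size_t n)
    {
        thread_buf.assign(n, 0.f);
        for(std::size_t i = 0; i < n; ++i)
            thread_buf[i] = src_op(src[i]);
    }

    void RunWrite(float* dst, std::size_t n) const
    {
        for(std::size_t i = 0; i < n && i < thread_buf.size(); ++i)
            dst[i] = dst_op(thread_buf[i]);
    }

    // same shape as the new Run() above: read phase, then write phase
    void Run(const float* src, float* dst, std::size_t n)
    {
        RunRead(src, n);
        RunWrite(dst, n);
    }
};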
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v5r1.hpp"
namespace ck {

@@ -31,13 +31,13 @@ template <index_t BlockSize,
typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4r1
+struct BlockwiseTensorSliceTransfer_v5r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
-__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
+__device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)

@@ -134,7 +134,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
-ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
+ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
...
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
namespace ck {
// this version does the following to avoid the scratch memory issue
// 1. use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. the threadwise transfer does not keep a reference to the tensor descriptor
// 3. the threadwise transfer's Run() does not construct a new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r1<SrcData,
DstData,
SrcDesc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
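A worked example, host-side and with made-up tile sizes, of the indexing used by BlockwiseTensorSliceTransfer_v6r1 above: each thread owns thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths elements, and its slice origin is thread_cluster_idx * thread_slice_lengths. The row-major tid-to-cluster mapping below ignores ThreadClusterArrangeOrder, which the real cluster descriptor honours.

#include <array>
#include <cstdio>

int main()
{
    constexpr std::array<int, 2> block_slice_lengths{8, 128};   // e.g. K0PerBlock x MPerBlock
    constexpr std::array<int, 2> thread_cluster_lengths{4, 64}; // 256 threads per block
    constexpr std::array<int, 2> thread_slice_lengths{
        block_slice_lengths[0] / thread_cluster_lengths[0],  // 2
        block_slice_lengths[1] / thread_cluster_lengths[1]}; // 2

    for(int tid = 0; tid < 256; tid += 100)
    {
        const int i0 = tid / thread_cluster_lengths[1]; // row-major cluster index
        const int i1 = tid % thread_cluster_lengths[1];
        std::printf("thread %3d: cluster idx (%d,%d), slice origin (%d,%d)\n",
                    tid,
                    i0,
                    i1,
                    i0 * thread_slice_lengths[0],
                    i1 * thread_slice_lengths[1]);
    }
    return 0;
}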
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r2.hpp"
namespace ck {
// this version does the following to avoid the scratch memory issue
// 1. use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. it does not keep a reference to the tensor descriptor
// 3. Run() does not construct a new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
Src1Data,
DstData,
Src0Desc,
Src1Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
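The two-source v6r2 variant above is what allows the output stage to fold extra per-element inputs (for example bias followed by ReLU) into the write of C. Once a thread's slice origin is fixed, the per-element work reduces to the sketch below; the op signature matches the functors added in this commit (e.g. AddRelu writes its result through the first argument), while the flat-pointer loop itself is an illustrative assumption.

#include <cstddef>

template <typename Op>
void two_source_elementwise(
    const float* src0, const float* src1, float* dst, std::size_t n, const Op& op)
{
    for(std::size_t i = 0; i < n; ++i)
        op(dst[i], src0[i], src1[i]); // e.g. AddRelu: dst = relu(src0 + src1)
}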
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp"
namespace ck {
// this version does the following to avoid the scratch memory issue
// 1. use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. the threadwise transfer does not keep a reference to the tensor descriptor
// 3. the threadwise transfer's Run() does not construct a new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename Src2Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
src2_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc2SliceOrigin(
src2_desc, src2_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const Src2Desc& src2_desc,
const Src2Buffer& src2_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
Src1Data,
Src2Data,
DstData,
Src0Desc,
Src1Desc,
Src2Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferSrc2ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
y = x;
}
// TODO remove this
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct AddRelu
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const
{
T a = x0 + x1;
y = a > 0 ? a : 0;
}
// TODO remove this
template <typename T1>
__host__ constexpr float operator()(float v0, T1 v1) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
return c;
}
// TODO remove this
template <typename T1>
__device__ constexpr float operator()(float v0, T1 v1) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float b = v1 + v0;
float c = b > 0 ? b : 0;
return c;
#endif
}
};
struct AddReluAdd
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1, const T& x2) const
{
T a = x0 + x1;
T b = a > 0 ? a : 0;
y = b + x2;
}
// TODO remove this
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
// TODO remove this
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float b = v1 + v2;
float c = (v0 > -v1) ? b + v0 : v2;
return c;
#endif
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct AddLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// this uses relatively few registers, but uses an fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this spills registers
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this uses lots of registers (but no spill)
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// this uses lots of registers (but no spill), 89 Tflops
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// this spills registers, 89 Tflops
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
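A hypothetical CPU-verification snippet exercising the out-parameter overloads defined above (an assumption, not the project's actual test code); it presumes the header is compiled in a HIP/CUDA translation unit so that __host__ __device__ is defined.

#include <cassert>

int main()
{
    using namespace ck::tensor_operation::element_wise;

    float y = 0.f;

    PassThrough{}(y, 3.5f); // y = x
    assert(y == 3.5f);

    AddRelu{}(y, -2.0f, 1.0f); // relu(-2 + 1) = 0
    assert(y == 0.0f);

    AddReluAdd{}(y, 2.0f, -1.0f, 0.5f); // relu(2 - 1) + 0.5 = 1.5
    assert(y == 1.5f);

    return 0;
}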
@@ -381,7 +381,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
"wrong!");
// A matrix blockwise copy
-auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,

@@ -405,7 +405,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
-auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...
@@ -6,7 +6,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_set.hpp"

@@ -380,7 +380,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
"wrong!");
// A matrix blockwise copy
-auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,

@@ -404,7 +404,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
-auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
...
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_set.hpp"
namespace ck {

@@ -40,15 +39,12 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
-constexpr index_t shared_block_size =
-GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-__shared__ FloatAB p_shared_block[shared_block_size];
+__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
-p_shared_block,
+p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -83,9 +79,6 @@ __global__ void
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
-constexpr index_t shared_block_size =
-GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(

@@ -102,12 +95,12 @@ __global__ void
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
-__shared__ FloatAB p_shared_block[shared_block_size];
+__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
-p_shared_block,
+p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -135,9 +128,8 @@ template <index_t BlockSize,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
-index_t MRepeat,
-index_t NRepeat,
+index_t MXdlPerWave,
+index_t NXdlPerWave,
-typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,

@@ -145,7 +137,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
-typename BBlockTransferThreadSliceLengths_K0_N_K1,
+bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,

@@ -153,17 +145,10 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
+bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
-index_t CThreadTransferDstScalarPerVector,
-typename AGridStepHacks,
-typename BGridStepHacks,
-typename CGridStepHacks,
-typename AGridMoveSliceWindowStepHacks,
-typename BGridMoveSliceWindowStepHacks,
-bool CAccessOrderMRepeatNRepeat,
-bool ABlockLdsExtraM,
-bool BBlockLdsExtraN>
+index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};

@@ -178,7 +163,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
-__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;

@@ -197,6 +182,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}();
+return a_block_desc_k0_m_k1;
+}
+__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
+{
+constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)

@@ -212,14 +204,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}();
+return b_block_desc_k0_n_k1;
+}
+__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+{
// LDS allocation for A and B: be careful of alignment
-constexpr auto a_block_space_size =
+constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+constexpr auto max_lds_align = K1;
+constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
-constexpr auto b_block_space_size =
+constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
-return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}

@@ -233,8 +236,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
-static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
-(NPerBlock % (NRepeat * NPerXDL)) == 0,
+static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
+(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);

@@ -324,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
-MRepeat,
-NRepeat,
+MXdlPerWave,
+NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

@@ -376,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
-FloatAB* __restrict__ p_shared_block,
+void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -409,42 +412,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
-constexpr auto a_block_desc_k0_m_k1 = [&]() {
-if constexpr(ABlockLdsExtraM)
-{
-return make_naive_tensor_descriptor(
-make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-}
-else
-{
-return make_naive_tensor_descriptor_aligned(
-make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-}
-}();
+constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
-constexpr auto b_block_desc_k0_n_k1 = [&]() {
-if constexpr(BBlockLdsExtraN)
-{
-return make_naive_tensor_descriptor(
-make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-}
-else
-{
-return make_naive_tensor_descriptor_aligned(
-make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-}
-}();
+constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4<BlockSize,
+BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
+ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
-ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,

@@ -460,19 +439,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
-true>(a_grid_desc_k0_m_k1,
+true>(
+a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
+a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
-a_element_op);
+ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4<BlockSize,
+BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
+ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
-BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,

@@ -488,11 +469,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
-true>(b_grid_desc_k0_n_k1,
+true>(
+b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
+b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
-b_element_op);
+ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx

@@ -510,68 +493,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
-MRepeat,
-NRepeat,
+MXdlPerWave,
+NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
-constexpr auto a_block_space_size =
+constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
-FloatAB* p_a_block = p_shared_block;
-FloatAB* p_b_block = p_shared_block + a_block_space_size;
+auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-// hack to control index calculation when iterating over A and B matrix for threadwise copy
-constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-// hack to control index calculation when move slice window for A and B matrix for
-// threadwise copy
-constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
-auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
// preload data into LDS
{
-a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
+b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
-// main body
-index_t k0_block_data_begin = 0;
+// Initialize C
c_thread_buf.Clear();
+// main body
if constexpr(HasMainKBlockLoop)
{
+index_t k0_block_data_begin = 0;
do
{
-a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-a_block_slice_copy_step,
-a_k0_m_k1_grid_move_slice_window_step_hack);
-b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-b_block_slice_copy_step,
-b_k0_n_k1_grid_move_slice_window_step_hack);
+a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-a_blockwise_copy.RunRead(
-a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
-b_blockwise_copy.RunRead(
-b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

@@ -619,8 +587,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),

@@ -668,11 +634,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-c_grid_buf,
-c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+c_grid_buf);
}
}
-}; // namespace ck
+};
} // namespace ck
#endif
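The LDS refactor above moves the A/B block descriptors into GetABlockDescriptor_K0PerBlock_MPerBlock_K1 / GetBBlockDescriptor_K0PerBlock_NPerBlock_K1, has the kernel allocate one __shared__ char array of GetSharedMemoryNumberOfByte() bytes, and lets Run() place the B tile at an element offset aligned to K1 behind the A tile. A host-side arithmetic sketch with made-up tile sizes; the padded-A layout and the 2-byte fp16 element size are assumptions for illustration.

#include <cstddef>
#include <cstdio>

// same idea as math::integer_least_multiple: round x up to a multiple of align
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

int main()
{
    constexpr int K0PerBlock = 4, MPerBlock = 256, NPerBlock = 128, K1 = 8;
    constexpr std::size_t sizeof_fp16 = 2; // assumption: FloatAB is half

    constexpr int a_elems = K0PerBlock * (MPerBlock + 1) * K1; // ABlockLdsExtraM padding
    constexpr int b_elems = K0PerBlock * NPerBlock * K1;       // unpadded B tile

    constexpr int a_aligned = integer_least_multiple(a_elems, K1);
    constexpr int b_aligned = integer_least_multiple(b_elems, K1);

    std::printf("A tile at element offset 0, B tile at element offset %d, total %zu bytes of LDS\n",
                a_aligned,
                (a_aligned + b_aligned) * sizeof_fp16);
    return 0;
}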
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_set.hpp"
namespace ck {

@@ -19,6 +18,9 @@ template <typename GridwiseGemm,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
+typename AElementwiseOperation,
+typename BElementwiseOperation,
+typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
__global__ void

@@ -31,6 +33,9 @@ __global__ void
const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =

@@ -45,6 +50,9 @@ __global__ void
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+a_element_op,
+b_element_op,
+c_element_op,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER

@@ -129,11 +137,6 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
-typename AGridStepHacks,
-typename BGridStepHacks,
-typename CGridStepHacks,
-typename AGridMoveSliceWindowStepHacks,
-typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>

@@ -371,6 +374,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

@@ -447,7 +451,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
// A matrix blockwise copy
auto a_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4<BlockSize,
+BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+AElementwiseOperation,
+ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,

@@ -469,12 +475,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+a_element_op,
a_b_k0_m_k1_block_desc,
-make_multi_index(0, 0, 0, 0));
+make_multi_index(0, 0, 0, 0),
+ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
-BlockwiseTensorSliceTransfer_v4<BlockSize,
+BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+BElementwiseOperation,
+ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,

@@ -496,8 +506,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+b_element_op,
b_b_k0_n_k1_block_desc,
-make_multi_index(0, 0, 0, 0));
+make_multi_index(0, 0, 0, 0),
+ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx

@@ -531,15 +543,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
-// hack to control index calculation when iterating over A and B matrix for threadwise copy
-constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-// hack to control index calculation when move slice window for A and B matrix for
-// threadwise copy
-constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -547,33 +550,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ...@@ -547,33 +550,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
} }
// Initialize C
c_thread_buf.Clear();
// main body // main body
index_t k_block_data_begin = 0;
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
index_t k0_block_data_begin = 0;
do do
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
a_block_slice_copy_step, b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...@@ -622,8 +623,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ...@@ -622,8 +623,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
...@@ -648,6 +647,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ...@@ -648,6 +647,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
FloatC, FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>, Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
...@@ -664,14 +664,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ...@@ -664,14 +664,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3], m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4], m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2])}; n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
} }
} }
}; // namespace ck }; // namespace ck
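One small but easy-to-miss addition in the loop above is the explicit "Initialize C" step (c_thread_buf.Clear()): because the blockwise GEMM accumulates c_mtx += transpose(a_mtx) * b_mtx across K0 blocks, the per-thread accumulator must start from zero. A scalar analogue of that accumulation, purely for illustration, is:

// Illustrative scalar analogue only; the real kernel accumulates MFMA results
// in c_thread_buf across K0PerBlock slices.
float accumulate_over_k(const float* a, const float* b, int k_len)
{
    float c = 0.f; // counterpart of c_thread_buf.Clear()
    for(int k = 0; k < k_len; ++k)
    {
        c += a[k] * b[k]; // counterpart of blockwise_gemm.Run(...)
    }
    return c;
}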
......
...@@ -6,9 +6,8 @@ ...@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp" #include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer_v1r4.hpp" #include "threadwise_tensor_slice_transfer_v1r4.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -88,7 +87,6 @@ template <index_t BlockSize, ...@@ -88,7 +87,6 @@ template <index_t BlockSize,
index_t K1Value, index_t K1Value,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -96,7 +94,7 @@ template <index_t BlockSize, ...@@ -96,7 +94,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1, index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1, bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
...@@ -104,17 +102,10 @@ template <index_t BlockSize, ...@@ -104,17 +102,10 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1, index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector>
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -410,11 +401,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -410,11 +401,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -430,19 +421,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -430,19 +421,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1, true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1, a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
a_element_op); ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -458,11 +451,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -458,11 +451,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1, true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1, b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
b_element_op); ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -496,15 +491,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -496,15 +491,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -512,34 +498,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -512,34 +498,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
} }
// main body // Initialize C
index_t k0_block_data_begin = 0; c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
index_t k0_block_data_begin = 0;
do do
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
a_block_slice_copy_step, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...@@ -588,8 +571,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -588,8 +571,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
...@@ -642,14 +623,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 ...@@ -642,14 +623,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
c_thread_buf, c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_buf, c0_grid_buf,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_buf); c1_grid_buf);
} }
} }
}; // namespace ck };
} // namespace ck } // namespace ck
#endif #endif
...@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf) DstBuffer& dst_buf)
{ {
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
...@@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
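The recurring static_for<0, i, 1> → static_for<1, i, 1> change in this and the following hunks fixes a double count in the forward_sweep parity computation: tmp is already seeded with the dimension-0 term (ordered_access_idx[I0], or ordered_access_lengths[I0] - 1 in the reset variant), so the accumulation over the remaining dimensions has to start at j = 1. A small standalone sketch of the intended mixed-radix computation, with made-up lengths and indices, is:

#include <array>
#include <cstddef>

// Sketch only (not CK code): decide, per dimension, whether the snake-like
// traversal sweeps forward or backward, from the parity of the access prefix.
template <std::size_t N>
std::array<bool, N> make_forward_sweep(const std::array<int, N>& lengths,
                                       const std::array<int, N>& idx)
{
    std::array<bool, N> forward_sweep{};
    forward_sweep[0] = true;
    for(std::size_t i = 1; i < N; ++i)
    {
        int tmp = idx[0];                  // seed already covers dimension 0 ...
        for(std::size_t j = 1; j < i; ++j) // ... so accumulation starts at j = 1
        {
            tmp = tmp * lengths[j] + idx[j];
        }
        forward_sweep[i] = (tmp % 2 == 0); // even prefix: forward, odd: backward
    }
    return forward_sweep;
}

// Example: lengths = {3, 2, 4}, idx = {1, 0, 2}. For i = 1 the correct tmp is
// idx[0] = 1 (odd, so backward), whereas the old j = 0 start computed
// 1 * 3 + 1 = 4 (even) and flipped the sweep direction.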
...@@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0]; index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
}); });
...@@ -638,7 +638,7 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -638,7 +638,7 @@ struct ThreadwiseTensorSliceTransfer_v2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
...@@ -835,7 +835,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -835,7 +835,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0]; index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
}); });
...@@ -992,7 +992,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -992,7 +992,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0]; index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
}); });
...@@ -1136,7 +1136,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -1136,7 +1136,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1; index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
}); });
...@@ -1196,7 +1196,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -1196,7 +1196,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1; index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
}); });
......
...@@ -116,9 +116,6 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -116,9 +116,6 @@ struct ThreadwiseTensorSliceTransfer_v1r4
constexpr auto dst_scalar_per_access = generate_sequence( constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{}; constexpr auto dim_access_order = DimAccessOrder{};
...@@ -141,7 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -141,7 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{}); Number<nDim>{});
// make forward steps: dst0 // make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this // TODO: fix this
const auto dst0_forward_steps = generate_tuple( const auto dst0_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -157,7 +155,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -157,7 +155,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{}); Number<nDim>{});
// make forward steps: dst1 // make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this // TODO: fix this
const auto dst1_forward_steps = generate_tuple( const auto dst1_forward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -187,7 +186,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -187,7 +186,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{}); Number<nDim>{});
// make backward steps: dst0 // make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this // TODO: fix this
const auto dst0_backward_steps = generate_tuple( const auto dst0_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -203,7 +203,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -203,7 +203,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{}); Number<nDim>{});
// make backward steps: dst1 // make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this // TODO: fix this
const auto dst1_backward_steps = generate_tuple( const auto dst1_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -229,7 +230,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -229,7 +230,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0]; index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
}); });
...@@ -397,14 +398,12 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -397,14 +398,12 @@ struct ThreadwiseTensorSliceTransfer_v1r4
typename SrcBuffer, typename SrcBuffer,
typename DstBuffer, typename DstBuffer,
typename Dst0Buffer, typename Dst0Buffer,
typename Dst1Buffer, typename Dst1Buffer>
typename DstStepHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc, const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf, const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc, const Dst1Desc& dst1_desc,
...@@ -427,7 +426,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -427,7 +426,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
src_buf, src_buf,
dst_desc, dst_desc,
dst_buf, dst_buf,
dst_step_hacks, f_step_hacks(dst_desc),
dst0_desc, dst0_desc,
dst0_buf, dst0_buf,
f_step_hacks(dst0_desc), f_step_hacks(dst0_desc),
...@@ -461,7 +460,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4 ...@@ -461,7 +460,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
......
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP #ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP #define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst ...@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 4. Use thread buffer // 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData, typename SrcData,
typename DstData, typename DstData,
...@@ -66,7 +67,7 @@ template <typename SliceLengths, ...@@ -66,7 +67,7 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to // RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation // save addr computation
struct ThreadwiseTensorSliceTransfer_v3r2 struct ThreadwiseTensorSliceTransfer_v3r1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -77,15 +78,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -77,15 +78,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_slice_origin, const Index& dst_slice_origin,
const SrcElementwiseOperation& src_element_op) const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op) src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{ {
} }
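Conceptually, giving the transfer both a SrcElementwiseOperation and a DstElementwiseOperation means every copied scalar passes through two pointwise functions: one on load and one on store. A plain, self-contained analogue of that composition (not the CK transfer itself; names here are illustrative) is:

// Conceptual analogue only: y[i] = dst_op(src_op(x[i])) for every element.
template <typename SrcOp, typename DstOp>
void copy_with_elementwise_ops(
    const float* src, float* dst, int n, SrcOp src_op, DstOp dst_op)
{
    for(int i = 0; i < n; ++i)
    {
        const float loaded = src_op(src[i]); // what RunRead's SrcElementwiseOperation does
        dst[i]             = dst_op(loaded); // what RunWrite's DstElementwiseOperation does
    }
}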
...@@ -165,7 +168,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -165,7 +168,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0]; index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
}); });
...@@ -412,7 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -412,7 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0]; index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
}); });
...@@ -442,13 +445,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -442,13 +445,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2
const bool is_dst_valid = const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_t = typename vector_type_maker_t<DstData, DstScalarPerVector>::type; using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ to dst_buf // copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
// apply DstElementwiseOperation on dst_vector_container
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
dst_vector_container.template AsType<DstData>()(i) =
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>( dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(), dst_coord_.GetOffset(),
is_dst_valid, is_dst_valid,
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)); dst_vector_container.template AsType<dst_vector_t>()[I0]);
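The new block above is where the destination operation actually runs: the vector is pulled out of the thread scratch, dst_element_op_ is applied to each of its DstScalarPerVector lanes, and only then is the single vectorized Set issued. A hypothetical lane-wise version for a 4-wide float vector (using the HIP/CUDA built-in float4 and a made-up ReLU-like functor, not the CK vector_type machinery) looks like:

// Hypothetical illustration of "apply the dst op per lane, then store once".
struct ReluSketch
{
    __host__ __device__ float operator()(float x) const { return x > 0.f ? x : 0.f; }
};

__device__ void apply_dst_op_then_store(float4& dst_slot, float4 v, ReluSketch op)
{
    v.x = op(v.x); // one application of dst_element_op_ per scalar lane,
    v.y = op(v.y); // mirroring the static_for<0, DstScalarPerVector, 1> above
    v.z = op(v.z);
    v.w = op(v.w);
    dst_slot = v;  // single vectorized store, mirroring dst_buf.Set(...)
}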
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -498,7 +512,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -498,7 +512,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template <typename SrcBuffer> template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{ {
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); constexpr index_t ntransform_src = remove_cvref_t<SrcDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
...@@ -512,7 +526,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -512,7 +526,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template <typename DstBuffer> template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{ {
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); // TODO: why need remove_cvref_t ?
constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
...@@ -548,7 +563,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -548,7 +563,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1; index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
}); });
...@@ -608,7 +623,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -608,7 +623,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1; index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
}); });
...@@ -811,6 +826,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -811,6 +826,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_; const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
}; };
} // namespace ck } // namespace ck
......