Commit aaa89914 authored by ltqin

Merge branch 'develop' into conv_splitk_f32

parents f8804804 acbd7bd7
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP

-#include "ignore.hpp"

namespace ck {

// StaticTensor for Scalar
@@ -17,10 +15,10 @@ struct StaticTensor
    static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

-    __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}

    __host__ __device__ constexpr StaticTensor(T invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
    {
    }
@@ -44,11 +42,11 @@ struct StaticTensor
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
-                return T{0};
+                return zero_scalar_value_;
            }
            else
            {
-                return invalid_element_value_;
+                return invalid_element_scalar_value_;
            }
        }
    }
@@ -71,12 +69,14 @@ struct StaticTensor
        }
        else
        {
-            return ignore;
+            return ignored_element_scalar_;
        }
    }

    StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
-    T invalid_element_value_ = T{0};
+    static constexpr T zero_scalar_value_ = T{0};
+    const T invalid_element_scalar_value_;
+    T ignored_element_scalar_;
};

// StaticTensor for vector
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
    using V = vector_type<S, ScalarPerVector>;

-    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
+        : invalid_element_scalar_value_{0}
+    {
+    }

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
    {
    }
@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
-                return S{0};
+                return zero_scalar_value_;
            }
            else
            {
-                return invalid_element_value_;
+                return invalid_element_scalar_value_;
            }
        }
    }
@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
        }
        else
        {
-            return ignore;
+            return ignored_element_scalar_;
        }
    }
@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
        else
        {
            // TODO: is this right way to initialize a vector?
-            return X{invalid_element_value_};
+            return X{invalid_element_scalar_value_};
        }
    }
@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
    }

    StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
-    S invalid_element_value_ = S{0};
+    static constexpr S zero_scalar_value_ = S{0};
+    const S invalid_element_scalar_value_ = S{0};
+    S ignored_element_scalar_;
};

template <AddressSpaceEnum_t AddressSpace,
...
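In plain terms, the change above replaces the ad-hoc `return T{0}` / `return ignore` paths with three dedicated members: a `static constexpr` zero value, a `const` user-supplied invalid value, and a mutable dummy element that absorbs out-of-range writes. The stand-alone sketch below mirrors that access pattern under those assumptions; the class and member names are illustrative, not the CK types.

```cpp
#include <cstddef>

// Minimal sketch (not the CK implementation): an element accessor that returns a
// well-defined scalar for out-of-range reads and a writable dummy scalar for
// out-of-range writes, mirroring zero_scalar_value_, invalid_element_scalar_value_
// and ignored_element_scalar_ in the diff above.
template <typename T, std::size_t N, bool UseNumericalZeroForInvalid>
struct TinyStaticTensor
{
    constexpr TinyStaticTensor(T invalid_value = T{0})
        : invalid_element_scalar_value_{invalid_value}
    {
    }

    // Read path: valid index -> stored element; invalid index -> zero or the
    // user-provided invalid value, depending on the template flag.
    constexpr const T& At(std::size_t i) const
    {
        if(i < N)
        {
            return data_[i];
        }
        else if constexpr(UseNumericalZeroForInvalid)
        {
            return zero_scalar_value_;
        }
        else
        {
            return invalid_element_scalar_value_;
        }
    }

    // Write path: invalid index -> a dummy element, so "tensor.At(i) = x" is a
    // silent no-op instead of undefined behaviour.
    constexpr T& At(std::size_t i) { return i < N ? data_[i] : ignored_element_scalar_; }

    T data_[N] = {};
    static constexpr T zero_scalar_value_ = T{0};
    const T invalid_element_scalar_value_;
    T ignored_element_scalar_{};
};

// Out-of-range access yields a defined value (0 here) rather than UB.
static_assert(TinyStaticTensor<int, 4, true>{}.At(100) == 0);
```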
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v3r2.hpp"
+#include "threadwise_tensor_slice_transfer_v3r1.hpp"

namespace ck {
@@ -15,9 +15,9 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
-          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
@@ -34,35 +34,38 @@ template <index_t BlockSize,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
+    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4(
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
+        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
-        const SrcElementwiseOperation& src_element_op)
+        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
+                               src_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
-                               src_element_op)
+                               dst_element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
-            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
+            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
@@ -74,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

-            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
@@ -114,6 +117,16 @@ struct BlockwiseTensorSliceTransfer_v4
        }
    }

+    template <typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
+    {
+        RunRead(src_desc, src_buf);
+        RunWrite(dst_desc, dst_buf);
+    }
+
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
@@ -152,8 +165,9 @@ struct BlockwiseTensorSliceTransfer_v4
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
                                           SrcElementwiseOperation,
+                                           DstElementwiseOperation,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
...
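The key bookkeeping in the renamed `BlockwiseTensorSliceTransfer_v4r1` is that each thread's slice is derived from the block slice, `thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths`, and a thread's starting offset is its cluster coordinate scaled by those lengths. A plain-C++ sketch of that index math, using a hypothetical 2D example instead of the CK Sequence/descriptor machinery:

```cpp
#include <array>
#include <cstdio>

int main()
{
    // Hypothetical 2D example: an 8x128 block slice copied by a 4x64 thread cluster.
    const std::array<int, 2> block_slice_lengths{8, 128};
    const std::array<int, 2> thread_cluster_lengths{4, 64};

    // thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths
    std::array<int, 2> thread_slice_lengths{};
    for(int d = 0; d < 2; ++d)
        thread_slice_lengths[d] = block_slice_lengths[d] / thread_cluster_lengths[d]; // {2, 2}

    // CalculateBottomIndex: unfold the 1D thread id into a cluster coordinate
    // (row-major here; the real code uses make_cluster_descriptor with an
    // arbitrary ThreadClusterArrangeOrder).
    const int thread_id = 70; // stand-in for get_thread_local_1d_id()
    const std::array<int, 2> cluster_idx{thread_id / thread_cluster_lengths[1],
                                         thread_id % thread_cluster_lengths[1]}; // {1, 6}

    // thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths
    std::array<int, 2> thread_data_idx_begin{};
    for(int d = 0; d < 2; ++d)
        thread_data_idx_begin[d] = cluster_idx[d] * thread_slice_lengths[d]; // {2, 12}

    std::printf("thread %d copies the sub-slice starting at (%d, %d) of size (%d, %d)\n",
                thread_id,
                thread_data_idx_begin[0],
                thread_data_idx_begin[1],
                thread_slice_lengths[0],
                thread_slice_lengths[1]);
}
```

The `is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>` assert in the diff is exactly the requirement that this per-thread tiling covers the whole block slice with no remainder.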
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v5r1.hpp"

namespace ck {
@@ -31,13 +31,13 @@ template <index_t BlockSize,
          typename DstVectorTensorContiguousDimOrder,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4r1
+struct BlockwiseTensorSliceTransfer_v5r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
+    __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                           const Index& src_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin)
@@ -134,7 +134,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
...
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
namespace ck {
// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r1<SrcData,
DstData,
SrcDesc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
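Functionally, `BlockwiseTensorSliceTransfer_v6r1::Run` above amounts to: each participating thread applies the element-wise operation while copying its private sub-slice, and threads whose id falls outside the cluster do nothing. A scalar reference version of that contract is sketched below, with hypothetical flat buffers and a contiguous per-thread partition; it ignores vectorization, LDS and the coordinate-transform machinery of the real code.

```cpp
#include <functional>
#include <vector>

// Reference semantics only: dst[i] = element_op(src[i]) over the block slice,
// partitioned contiguously across the thread cluster. The loop over tid stands
// in for the threads of one workgroup executing in parallel.
void blockwise_copy_reference(const std::vector<float>& src,
                              std::vector<float>& dst,
                              const std::function<void(float&, const float&)>& element_op,
                              int block_size,
                              int cluster_size,
                              int thread_slice_size)
{
    for(int tid = 0; tid < block_size; ++tid)
    {
        // Mirrors the "BlockSize == cluster size or tid < cluster size" guard above.
        if(tid >= cluster_size)
            continue;

        const int begin = tid * thread_slice_size; // thread_data_idx_begin, flattened
        for(int i = 0; i < thread_slice_size; ++i)
            element_op(dst[begin + i], src[begin + i]);
    }
}
```

With a pass-through operation (`[](float& y, const float& x) { y = x; }`) this is a plain tiled copy; the v6r2/v6r3 variants that follow extend the same contract to two and three source tensors.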
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r2.hpp"
namespace ck {
// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
Src1Data,
DstData,
Src0Desc,
Src1Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp"
namespace ck {
// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename Src2Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
src2_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc2SliceOrigin(
src2_desc, src2_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const Src2Desc& src2_desc,
const Src2Buffer& src2_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
Src1Data,
Src2Data,
DstData,
Src0Desc,
Src1Desc,
Src2Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferSrc2ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
y = x;
}
// TODO remove this
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct AddRelu
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const
{
T a = x0 + x1;
y = a > 0 ? a : 0;
}
// TODO remove this
template <typename T1>
__host__ constexpr float operator()(float v0, T1 v1) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
return c;
}
// TODO remove this
template <typename T1>
__device__ constexpr float operator()(float v0, T1 v1) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float b = v1 + v0;
float c = b > 0 ? b : 0;
return c;
#endif
}
};
struct AddReluAdd
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1, const T& x2) const
{
T a = x0 + x1;
T b = a > 0 ? a : 0;
y = b + x2;
}
// TODO remove this
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
// TODO remove this
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float b = v1 + v2;
float c = (v0 > -v1) ? b + v0 : v2;
return c;
#endif
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
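The `#else` branch of the device `AddReluAdd` avoids an explicit `max` by rewriting `max(v0 + v1, 0) + v2` as `(v0 > -v1) ? (v1 + v2) + v0 : v2`: when `v0 + v1 > 0` the ReLU passes the sum through, otherwise only `v2` survives, so the two forms agree. A small host-side check of that identity follows; the test harness is mine, not part of the commit, and the sample values are exact binary fractions so exact comparison is safe.

```cpp
#include <algorithm>
#include <cassert>

// relu(x0 + x1) + x2, written the straightforward way.
float add_relu_add_reference(float v0, float v1, float v2)
{
    return std::max(v0 + v1, 0.0f) + v2;
}

// The branch-based form used in the device operator above.
float add_relu_add_branchy(float v0, float v1, float v2)
{
    float b = v1 + v2;
    return (v0 > -v1) ? b + v0 : v2;
}

int main()
{
    for(float v0 : {-2.0f, -0.5f, 0.0f, 1.5f})
        for(float v1 : {-1.0f, 0.25f, 3.0f})
            for(float v2 : {-4.0f, 0.0f, 2.0f})
                assert(add_relu_add_reference(v0, v1, v2) == add_relu_add_branchy(v0, v1, v2));
}
```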
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct AddLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this spill register
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this use lots of registers (but no spill)
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// this spill registers, 89 Tflops
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
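The selected `#elif 1` branch of `AddLeakyReluAdd` relies on the algebraic identity `max(alpha*(v0+v1), 0) + v2 == alpha * (v2/alpha + max(v0+v1, 0))`, valid because `alpha > 0`; the in-code comments record the register/spill and fp64-multiply trade-offs that motivated trying it. A host-side sketch of that identity, checked with a tolerance since the two orderings round differently (harness and constants are mine):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Straightforward form: scale, clamp, then add (matches the __host__ operator above).
float reference(float v0, float v1, float v2)
{
    float a = v0 + v1;
    float b = 0.1f * a;
    float c = std::max(b, 0.0f);
    return c + v2;
}

// Rewritten form from the "#elif 1" branch: pre-divide v2 by alpha so the final
// result is a single multiply, alpha * (v2/alpha + max(v0 + v1, 0)).
float rewritten(float v0, float v1, float v2)
{
    constexpr float alpha     = 0.1f;
    constexpr float alpha_inv = 1.0f / alpha;
    float a = v2 * alpha_inv;
    float b = v1 + v0;
    float c = std::max(b, 0.0f);
    return alpha * (a + c);
}

int main()
{
    for(float v0 : {-3.0f, -0.5f, 0.0f, 2.0f})
        for(float v1 : {-1.0f, 0.5f, 4.0f})
            for(float v2 : {-2.0f, 0.0f, 7.0f})
                assert(std::fabs(reference(v0, v1, v2) - rewritten(v0, v1, v2)) < 1e-4f);
}
```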
@@ -381,7 +381,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
            "wrong!");

        // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
@@ -405,7 +405,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
            make_multi_index(0, 0, 0, 0, 0));

        // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...
@@ -6,7 +6,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_set.hpp"
@@ -380,7 +380,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
            "wrong!");

        // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
@@ -404,7 +404,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
            make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
...
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_set.hpp"

namespace ck {
@@ -22,7 +21,7 @@ template <typename GridwiseGemm,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-          typename Block2CTileMap,
+          typename CBlockClusterAdaptor,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -37,7 +36,7 @@ __global__ void
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
-            const Block2CTileMap block_2_ctile_map)
+            const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -54,7 +53,7 @@ __global__ void
                      a_element_op,
                      b_element_op,
                      c_element_op,
-                      block_2_ctile_map);
+                      c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
@@ -156,11 +155,6 @@ template <index_t BlockSize,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsExtraM,
          bool BBlockLdsExtraN>
@@ -401,6 +395,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];
+        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
@@ -477,61 +472,65 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        }();

        // A matrix blockwise copy
        auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            AElementwiseOperation,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, K0PerBlock, MPerBlock, K1>,
-                                            ABlockTransferThreadSliceLengths_K0_M_K1,
-                                            ABlockTransferThreadClusterLengths_K0_M_K1,
-                                            ABlockTransferThreadClusterArrangeOrder,
-                                            FloatAB,
-                                            FloatAB,
-                                            decltype(a_b_k0_m_k1_grid_desc),
-                                            decltype(a_b_k0_m_k1_block_desc),
-                                            ABlockTransferSrcAccessOrder,
-                                            Sequence<0, 2, 1, 3>,
-                                            ABlockTransferSrcVectorDim,
-                                            3,
-                                            ABlockTransferSrcScalarPerVector,
-                                            ABlockTransferDstScalarPerVector_K1,
-                                            1,
-                                            1,
-                                            AThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(
-                a_b_k0_m_k1_grid_desc,
-                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
-                a_b_k0_m_k1_block_desc,
-                make_multi_index(0, 0, 0, 0),
-                a_element_op);
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                              AElementwiseOperation,
+                                              ck::tensor_operation::element_wise::PassThrough,
+                                              InMemoryDataOperationEnum_t::Set,
+                                              Sequence<1, K0PerBlock, MPerBlock, K1>,
+                                              ABlockTransferThreadSliceLengths_K0_M_K1,
+                                              ABlockTransferThreadClusterLengths_K0_M_K1,
+                                              ABlockTransferThreadClusterArrangeOrder,
+                                              FloatAB,
+                                              FloatAB,
+                                              decltype(a_b_k0_m_k1_grid_desc),
+                                              decltype(a_b_k0_m_k1_block_desc),
+                                              ABlockTransferSrcAccessOrder,
+                                              Sequence<0, 2, 1, 3>,
+                                              ABlockTransferSrcVectorDim,
+                                              3,
+                                              ABlockTransferSrcScalarPerVector,
+                                              ABlockTransferDstScalarPerVector_K1,
+                                              1,
+                                              1,
+                                              AThreadTransferSrcResetCoordinateAfterRun,
+                                              true>(
+                a_b_k0_m_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+                a_element_op,
+                a_b_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            BElementwiseOperation,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, K0PerBlock, NPerBlock, K1>,
-                                            BBlockTransferThreadSliceLengths_K0_N_K1,
-                                            BBlockTransferThreadClusterLengths_K0_N_K1,
-                                            BBlockTransferThreadClusterArrangeOrder,
-                                            FloatAB,
-                                            FloatAB,
-                                            decltype(b_b_k0_n_k1_grid_desc),
-                                            decltype(b_b_k0_n_k1_block_desc),
-                                            BBlockTransferSrcAccessOrder,
-                                            Sequence<0, 2, 1, 3>,
-                                            BBlockTransferSrcVectorDim,
-                                            3,
-                                            BBlockTransferSrcScalarPerVector,
-                                            BBlockTransferDstScalarPerVector_K1,
-                                            1,
-                                            1,
-                                            BThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(
-                b_b_k0_n_k1_grid_desc,
-                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
-                b_b_k0_n_k1_block_desc,
-                make_multi_index(0, 0, 0, 0),
-                b_element_op);
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                              BElementwiseOperation,
+                                              ck::tensor_operation::element_wise::PassThrough,
+                                              InMemoryDataOperationEnum_t::Set,
+                                              Sequence<1, K0PerBlock, NPerBlock, K1>,
+                                              BBlockTransferThreadSliceLengths_K0_N_K1,
+                                              BBlockTransferThreadClusterLengths_K0_N_K1,
+                                              BBlockTransferThreadClusterArrangeOrder,
+                                              FloatAB,
+                                              FloatAB,
+                                              decltype(b_b_k0_n_k1_grid_desc),
+                                              decltype(b_b_k0_n_k1_block_desc),
+                                              BBlockTransferSrcAccessOrder,
+                                              Sequence<0, 2, 1, 3>,
+                                              BBlockTransferSrcVectorDim,
+                                              3,
+                                              BBlockTransferSrcScalarPerVector,
+                                              BBlockTransferDstScalarPerVector_K1,
+                                              1,
+                                              1,
+                                              BThreadTransferSrcResetCoordinateAfterRun,
+                                              true>(
+                b_b_k0_n_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+                b_element_op,
+                b_b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
@@ -565,15 +564,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

-        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-        // hack to control index calculation when move slice window for A and B matrix for
-        // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
@@ -581,33 +571,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        // preload data into LDS
        {
-            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
        }

+        // Initialize C
+        c_thread_buf.Clear();

        // main body
-        index_t k_block_data_begin = 0;
        if constexpr(HasMainKBlockLoop)
        {
+            index_t k0_block_data_begin = 0;
            do
            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);

-                a_blockwise_copy.RunRead(
-                    a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

                block_sync_lds();

-                b_blockwise_copy.RunRead(
-                    b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
@@ -656,8 +644,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
@@ -706,8 +692,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                c_thread_buf,
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                c_grid_buf,
-                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+                c_grid_buf);
        }
    }
}; // namespace ck
...
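For orientation, the main k0 loop above follows a fixed ordering: advance the A/B grid windows, issue the next global-memory RunRead, synchronize LDS, issue the second RunRead, then run the blockwise GEMM on the tile already resident in LDS. The hunk cuts off before the tail of the loop body, so the schematic below only mirrors the visible part and assumes the RunWrite of the freshly fetched tiles happens in the elided remainder; all calls are no-op stubs, not the CK API.

```cpp
// Schematic of the k0 pipeline structure shown in the diff above (stubs only).
struct PipelineSketch
{
    void run(int K0, int K0PerBlock)
    {
        // preload the first tile into LDS
        run_read_a();
        run_read_b();
        run_write_a();
        run_write_b();

        for(int k0 = 0; k0 + K0PerBlock < K0; k0 += K0PerBlock) // "HasMainKBlockLoop" part
        {
            move_a_window(); // a_blockwise_copy.MoveSrcSliceWindow(..., a_block_slice_copy_step)
            move_b_window(); // b_blockwise_copy.MoveSrcSliceWindow(..., b_block_slice_copy_step)

            run_read_a();    // fetch the next A tile from global memory into registers
            sync_lds();      // barrier separating the previous LDS writes from this GEMM
            run_read_b();    // fetch the next B tile

            gemm_on_lds();   // blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf)
            // ... RunWrite of the freshly read tiles into LDS presumably follows
            //     in the part of the loop body not shown in this hunk.
        }
    }

    // no-op stand-ins for the CK calls
    void run_read_a() {}   void run_read_b() {}
    void run_write_a() {}  void run_write_b() {}
    void move_a_window() {} void move_b_window() {}
    void sync_lds() {}     void gemm_on_lds() {}
};

int main() { PipelineSketch{}.run(64, 8); }
```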
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer_v1r4.hpp"
-#include "threadwise_tensor_slice_set.hpp"

namespace ck {
@@ -88,7 +87,6 @@ template <index_t BlockSize,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
-          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
@@ -96,7 +94,7 @@ template <index_t BlockSize,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K0_N_K1,
+          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
@@ -104,17 +102,10 @@ template <index_t BlockSize,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat,
-          bool ABlockLdsExtraM,
-          bool BBlockLdsExtraN>
+          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
{
    static constexpr auto I0 = Number<0>{};
@@ -410,59 +401,63 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        // A matrix blockwise copy
        auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            AElementwiseOperation,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<K0PerBlock, MPerBlock, K1>,
-                                            ABlockTransferThreadSliceLengths_K0_M_K1,
-                                            ABlockTransferThreadClusterLengths_K0_M_K1,
-                                            ABlockTransferThreadClusterArrangeOrder,
-                                            FloatAB,
-                                            FloatAB,
-                                            decltype(a_grid_desc_k0_m_k1),
-                                            decltype(a_block_desc_k0_m_k1),
-                                            ABlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
-                                            ABlockTransferSrcVectorDim,
-                                            2,
-                                            ABlockTransferSrcScalarPerVector,
-                                            ABlockTransferDstScalarPerVector_K1,
-                                            1,
-                                            1,
-                                            AThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(a_grid_desc_k0_m_k1,
-                                                  make_multi_index(0, m_block_data_idx_on_grid, 0),
-                                                  a_block_desc_k0_m_k1,
-                                                  make_multi_index(0, 0, 0),
-                                                  a_element_op);
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                              AElementwiseOperation,
+                                              ck::tensor_operation::element_wise::PassThrough,
+                                              InMemoryDataOperationEnum_t::Set,
+                                              Sequence<K0PerBlock, MPerBlock, K1>,
+                                              ABlockTransferThreadClusterLengths_K0_M_K1,
+                                              ABlockTransferThreadClusterArrangeOrder,
+                                              FloatAB,
+                                              FloatAB,
+                                              decltype(a_grid_desc_k0_m_k1),
+                                              decltype(a_block_desc_k0_m_k1),
+                                              ABlockTransferSrcAccessOrder,
+                                              Sequence<1, 0, 2>,
+                                              ABlockTransferSrcVectorDim,
+                                              2,
+                                              ABlockTransferSrcScalarPerVector,
+                                              ABlockTransferDstScalarPerVector_K1,
+                                              1,
+                                              1,
+                                              AThreadTransferSrcResetCoordinateAfterRun,
+                                              true>(
+                a_grid_desc_k0_m_k1,
+                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                a_element_op,
+                a_block_desc_k0_m_k1,
+                make_multi_index(0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            BElementwiseOperation,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<K0PerBlock, NPerBlock, K1>,
-                                            BBlockTransferThreadSliceLengths_K0_N_K1,
-                                            BBlockTransferThreadClusterLengths_K0_N_K1,
-                                            BBlockTransferThreadClusterArrangeOrder,
-                                            FloatAB,
-                                            FloatAB,
-                                            decltype(b_grid_desc_k0_n_k1),
-                                            decltype(b_block_desc_k0_n_k1),
-                                            BBlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
-                                            BBlockTransferSrcVectorDim,
-                                            2,
-                                            BBlockTransferSrcScalarPerVector,
-                                            BBlockTransferDstScalarPerVector_K1,
-                                            1,
-                                            1,
-                                            BThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(b_grid_desc_k0_n_k1,
-                                                  make_multi_index(0, n_block_data_idx_on_grid, 0),
-                                                  b_block_desc_k0_n_k1,
-                                                  make_multi_index(0, 0, 0),
-                                                  b_element_op);
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                              BElementwiseOperation,
+                                              ck::tensor_operation::element_wise::PassThrough,
+                                              InMemoryDataOperationEnum_t::Set,
+                                              Sequence<K0PerBlock, NPerBlock, K1>,
+                                              BBlockTransferThreadClusterLengths_K0_N_K1,
+                                              BBlockTransferThreadClusterArrangeOrder,
+                                              FloatAB,
+                                              FloatAB,
+                                              decltype(b_grid_desc_k0_n_k1),
+                                              decltype(b_block_desc_k0_n_k1),
+                                              BBlockTransferSrcAccessOrder,
+                                              Sequence<1, 0, 2>,
+                                              BBlockTransferSrcVectorDim,
+                                              2,
+                                              BBlockTransferSrcScalarPerVector,
+                                              BBlockTransferDstScalarPerVector_K1,
+                                              1,
+                                              1,
+                                              BThreadTransferSrcResetCoordinateAfterRun,
+                                              true>(
+                b_grid_desc_k0_n_k1,
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
+                b_element_op,
+                b_block_desc_k0_n_k1,
+                make_multi_index(0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
@@ -496,15 +491,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-        // hack to control index calculation when move slice window for A and B matrix for
-        // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
@@ -512,34 +498,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        // preload data into LDS
        {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
        }

-        // main body
-        index_t k0_block_data_begin = 0;
+        // Initialize C
+        c_thread_buf.Clear();

+        // main body
        if constexpr(HasMainKBlockLoop)
        {
+            index_t k0_block_data_begin = 0;
            do
            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

-                a_blockwise_copy.RunRead(
-                    a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

                block_sync_lds();

-                b_blockwise_copy.RunRead(
-                    b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
@@ -588,8 +571,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
@@ -642,14 +623,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
                c_thread_buf,
                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                c_grid_buf,
-                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
                c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                c0_grid_buf,
                c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                c1_grid_buf);
        }
    }
-}; // namespace ck
+};
} // namespace ck
#endif
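The C write-out in this v2r5 kernel passes two extra grid tensors (`c0_grid_buf`, `c1_grid_buf`) to the threadwise v1r4 transfer, i.e. the output copy fuses a three-input element-wise epilogue over the accumulator and those tensors. Which operation is applied depends on the `CElementwiseOperation` the kernel is instantiated with; `AddReluAdd` from the element-wise header above is used below only as an illustrative choice, and the roles of c0/c1 (e.g. bias and residual) are my assumption. A flat scalar sketch of what that epilogue amounts to per element:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hedged sketch of the fused C epilogue: walk the C tile and combine the
// accumulator with two extra tensors through a three-input element-wise op.
// Flat buffers here; the real transfer iterates an 8-d
// (m0, n0, m1, n1, m2, m3, m4, n2) descriptor with vectorized accesses.
void c_epilogue_reference(const std::vector<float>& acc, // c_thread_buf
                          const std::vector<float>& c0,  // e.g. a broadcast bias (assumed role)
                          const std::vector<float>& c1,  // e.g. a residual input (assumed role)
                          std::vector<float>& c_out)     // c_grid_buf
{
    for(std::size_t i = 0; i < c_out.size(); ++i)
    {
        const float relu = std::max(acc[i] + c0[i], 0.0f); // AddReluAdd semantics as an example
        c_out[i]         = relu + c1[i];
    }
}
```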
...@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf) DstBuffer& dst_buf)
{ {
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{}; constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
...@@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1; index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
}); });
...@@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2
static_for<1, nDim, 1>{}([&](auto i) { static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0]; index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) { static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
}); });
...@@ -638,7 +638,7 @@ struct ThreadwiseTensorSliceTransfer_v2
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
...@@ -835,7 +835,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
...@@ -992,7 +992,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
...@@ -1136,7 +1136,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
...@@ -1196,7 +1196,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
...
...@@ -116,9 +116,6 @@ struct ThreadwiseTensorSliceTransfer_v1r4
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
...@@ -141,7 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{});
// make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_forward_steps = generate_tuple(
[&](auto i) {
...@@ -157,7 +155,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
...@@ -187,7 +186,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{});
// make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_backward_steps = generate_tuple(
[&](auto i) {
...@@ -203,7 +203,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
Number<nDim>{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
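The reworded WARNING in the four hunks above states the actual precondition: the forward and backward step tuples for dst, dst0 and dst1 are all built from the dst tensor's per-dimension access width, so the three buffers must share one DstScalarPerVector. A schematic of that shared-step construction follows; the lambda body is an assumption for illustration and is not copied from the source:

// Schematic only: each forward step advances by dst_scalar_per_access[i] along
// dimension i, and the same widths are reused for the dst0 (and dst1) descriptors.
const auto dst0_forward_steps = generate_tuple(
    [&](auto i) {
        Index forward_step_idx;
        static_for<0, nDim, 1>{}([&](auto j) {
            forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
        });
        return make_tensor_coordinate_step(dst0_desc, forward_step_idx);
    },
    Number<nDim>{});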
...@@ -229,7 +230,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
...@@ -397,14 +398,12 @@ struct ThreadwiseTensorSliceTransfer_v1r4
typename SrcBuffer,
typename DstBuffer,
typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks>
typename Dst1Buffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc,
...@@ -427,7 +426,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
src_buf,
dst_desc,
dst_buf,
dst_step_hacks,
f_step_hacks(dst_desc),
dst0_desc,
dst0_buf,
f_step_hacks(dst0_desc),
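With the DstStepHacks template parameter removed, this convenience overload now derives the dst step hacks internally via f_step_hacks, exactly as it already did for dst0 and dst1. A hedged sketch of what such a default plausibly looks like, following the ntransform/zeros pattern shown earlier in this commit; the helper name and the shape of the returned tuple are assumptions, not taken from the source:

// Hypothetical helper for illustration: build "no-op" step hacks (all-zero
// sequences sized by the descriptor's transform count) for a given descriptor.
constexpr auto make_zero_step_hacks = [](auto desc) {
    constexpr index_t ntransform = remove_cvref_t<decltype(desc)>::GetNumOfTransform();
    constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};
    return make_tuple(zeros, zeros); // assumed: forward-move and backward-move hacks
};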
...@@ -461,7 +460,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
...