Commit ff4f8ba8 authored by Chao Liu

refactoring; add readme

parent 25e35b59
# Instructions for ```example_gemm_bias_add_fastgelu_xdl_fp16```
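This example runs an fp16 GEMM with a fused bias-add / add / FastGeLU epilogue on the XDL (matrix-core) path. Judging from the operand dumps below, it computes roughly `E = FastGelu(A * B + D0 + D1)`, where `D0` is a bias row broadcast along M (note its `{0, 1}` strides) and `D1` is a full M x N addend.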
## Run ```example_gemm_bias_add_fastgelu_xdl_fp16```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_gemm_bias_add_fastgelu_xdl_fp16 1 1 1
```
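The problem size can also be passed explicitly. Going by the `arg4 to 9` usage text printed by the example, a full invocation would look like the sketch below (the trailing sizes are illustrative, not required):
```bash
# arg4 to 9: M (multiple of 256), N (multiple of 128), K (multiple of 32), StrideA, StrideB, StrideE
./bin/example_gemm_bias_add_fastgelu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096
```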
Result (MI100 @ 1087MHz, 133.5 TFlops peak FP16)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
```
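As a sanity check, the reported TFlops and GB/s follow directly from the problem size and the measured time. A minimal back-of-the-envelope sketch (assuming 2 flops per multiply-accumulate and 2-byte fp16 elements for A, B, D1 and E plus the N-length D0 bias row; this mirrors the reported numbers, not necessarily the library's internal accounting):
```cpp
#include <cstdio>

int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const double time_s = 1.26914e-3; // measured kernel time from the run above

    const double flop  = 2.0 * M * N * K;                           // 1 MAC = 2 flops
    const double bytes = 2.0 * (M * K + K * N + N + M * N + M * N); // A, B, D0 (bias row), D1, E in fp16

    std::printf("%.3f TFlops, %.3f GB/s\n", flop / time_s / 1e12, bytes / time_s / 1e9);
    // prints roughly: 101.525 TFlops, 100.804 GB/s
}
```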
@@ -113,7 +113,7 @@ int main(int argc, char* argv[])
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-       printf("arg3: time kernel (0=n0, 1=yes)\n");
+       printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
        exit(0);
    }
@@ -161,12 +161,6 @@ int main(int argc, char* argv[])
        d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
    }

-   std::cout << "a: " << a_m_k.mDesc.GetElementSpace() << std::endl;
-   std::cout << "b: " << b_k_n.mDesc.GetElementSpace() << std::endl;
-   std::cout << "d0: " << d0_m_n.mDesc.GetElementSpace() << std::endl;
-   std::cout << "d1: " << d1_m_n.mDesc.GetElementSpace() << std::endl;
-   std::cout << "e: " << e_m_n_device_result.mDesc.GetElementSpace() << std::endl;
-
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
#pragma once

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v7.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v7
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const Src2Desc& src2_desc,
                                                           const Index& src2_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(tie(src0_desc, src1_desc, src2_desc),
                               make_tuple(make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>()),
                               tie(dst_desc),
                               make_tuple(make_zero_multi_index<nDim>()),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(
                tie(src0_desc, src1_desc, src2_desc),
                make_tuple(src0_block_slice_origin + thread_data_idx_begin,
                           src1_block_slice_origin + thread_data_idx_begin,
                           src2_block_slice_origin + thread_data_idx_begin));

            threadwise_transfer_.SetDstSliceOrigin(
                tie(dst_desc), make_tuple(dst_block_slice_origin + thread_data_idx_begin));
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
                                     tie(src0_buf, src1_buf, src2_buf),
                                     tie(dst_desc),
                                     tie(dst_buf));
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
        Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
        Tuple<remove_cvref_t<DstData>>,
        Tuple<remove_reference_t<Src0Desc>&,
              remove_reference_t<Src1Desc>&,
              remove_reference_t<Src2Desc>&>,
        Tuple<remove_reference_t<DstDesc>&>,
        ElementwiseOperation,
        decltype(thread_slice_lengths),
        DimAccessOrder,
        VectorDim,
        ScalarPerVector,
        Sequence<ThreadTransferSrc0ResetCoordinateAfterRun,
                 ThreadTransferSrc1ResetCoordinateAfterRun,
                 ThreadTransferSrc2ResetCoordinateAfterRun>,
        Sequence<ThreadTransferDstResetCoordinateAfterRun>,
        DstInMemOp>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
@@ -558,46 +558,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
-#if 1
-        {
-            std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
-                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
-                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
-                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
-
-            std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
-                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
-                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
-                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
-
-            std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
-                      << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-
-            std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I0)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I1)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I2)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I3)
-                      << "}" << std::endl;
-
-            std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I0)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I1)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I2)
-                      << ", "
-                      << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I3)
-                      << "}" << std::endl;
-
-            std::cout << "p_ds_grid{ " << arg.p_ds_grid_[I0] << ", " << arg.p_ds_grid_[I1]
-                      << "}" << std::endl;
-        }
-#endif
-
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                        arg.b_grid_desc_bk0_n_bk1_,
                                        arg.e_grid_desc_m_n_,
@@ -7,8 +7,8 @@
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
+#include "thread_group_tensor_slice_transfer_v7.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

@@ -223,7 +223,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }

-   // return block_id to C matrix tile idx (m0, n0) mapping
+   // return block_id to E matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
    {

@@ -579,7 +579,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
            make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
            cde_element_op};
#else
-       auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r1<
+       auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,            // ThreadGroup
            CDEElementwiseOperation,    // ElementwiseOperation,
            EGlobalMemoryDataOperation, // DstInMemOp,

@@ -588,18 +588,28 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
            1,
            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatCShuffle,        // typename Src0Data,
-           FloatE,               // typename DstData,
+           remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
+           remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
+           FloatE,                                      // typename DstData,
            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
+           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
            Sequence<0, 1, 2, 3>,                             // typename DimAccessOrder,
            3,                                                // index_t VectorDim,
            CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
            true,  // bool ThreadTransferSrc0ResetCoordinateAfterRun,
+           false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
+           false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
            false> // bool ThreadTransferDstResetCoordinateAfterRun>
            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(0, 0, 0, 0),
+            ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
+            make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
+            ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
+            make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
             e_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
             cde_element_op};

@@ -660,6 +670,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
            cde_block_copy_lds_and_global.Run(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                c_shuffle_block_buf,
+               ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
+               ds_grid_buf[I0],
+               ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
+               ds_grid_buf[I1],
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                e_grid_buf);
#endif
@@ -28,12 +28,14 @@ template <typename SrcDatas,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
-         bool SrcResetCoordinateAfterRun,
-         bool DstResetCoordinateAfterRun,
+         typename SrcResetCoordinateAfterRunFlags, // Sequence<...>
+         typename DstResetCoordinateAfterRunFlags, // Sequence<...>
          InMemoryDataOperationEnum... DstInMemOps>
struct ThreadwiseTensorSliceTransfer_v7
{
    static constexpr auto I0 = Number<0>{};
+   static constexpr auto I1 = Number<1>{};
+   static constexpr auto I2 = Number<2>{};

    static constexpr index_t nDim = SliceLengths::Size();

@@ -46,7 +48,7 @@ struct ThreadwiseTensorSliceTransfer_v7
    template <typename Descs,
              typename Indices,
              enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
-   constexpr auto MakeCoordiantes(const Descs& descs, const Indices& indices)
+   static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
    {
        return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
                              Number<Descs::Size()>{});

@@ -95,22 +97,28 @@ struct ThreadwiseTensorSliceTransfer_v7
        });
    }

+   // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
+   // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
+   // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
+   // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
    template <typename SrcBuffers,
              typename DstBuffers,
              enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
-                             DstDescs::Size() == DstBuffers::Size()>,
-             bool = false>
+                             DstDescs::Size() == DstBuffers::Size(),
+                         bool> = false>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
-                       DstBuffers& dst_bufs)
+                       const DstBuffers& dst_bufs)
    {
        auto generate_vectors = [&](auto data_types) {
+           constexpr index_t num = data_types.Size();

            return generate_tuple([&](auto i) {
-               using DataType = decltype(data_types[i]);
+               using DataType = remove_cvref_t<decltype(data_types[i])>;

                return vector_type_maker_t<DataType, ScalarPerVector>{};
-           });
+           }, Number<num>{});
        };

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

@@ -122,7 +130,7 @@ struct ThreadwiseTensorSliceTransfer_v7
            // copy data from src_bufs into src_vectors
            static_for<0, nSrc, 1>{}([&](auto i) {
-               using src_vector_t = typename remove_cv_t<decltype(src_vectors[i])>::type;
+               using src_vector_t = remove_cvref_t<typename decltype(src_vectors[i])::type>;

                const bool is_src_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],

@@ -135,11 +143,16 @@ struct ThreadwiseTensorSliceTransfer_v7
            // apply pointwise function
            // FIXME: support tuple of arbitary size
            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-               using SrcData0 = decltype(SrcDatas{}.At[I0]);
-               using DstData0 = decltype(DstDatas{}.At[I0]);
+               using SrcData0 = remove_cvref_t<decltype(SrcDatas{}[I0])>;
+               using SrcData1 = remove_cvref_t<decltype(SrcDatas{}[I1])>;
+               using SrcData2 = remove_cvref_t<decltype(SrcDatas{}[I2])>;
+               using DstData0 = remove_cvref_t<decltype(DstDatas{}[I0])>;

                element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
-                           src_vectors[I0].template AsType<SrcData0>()[i]);
+                           src_vectors[I0].template AsType<SrcData0>()[i],
+                           src_vectors[I1].template AsType<SrcData1>()[i],
+                           src_vectors[I2].template AsType<SrcData2>()[i]);
            });

            // copy data from buf_vectors into dst_bufs

@@ -178,25 +191,25 @@ struct ThreadwiseTensorSliceTransfer_v7
            });

            // move coordinate back to slice origin (or not)
-           if constexpr(SrcResetCoordinateAfterRun)
-           {
-               static_for<0, nSrc, 1>{}([&](auto i) {
+           static_for<0, nSrc, 1>{}([&](auto i) {
+               if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
+               {
                    const auto src_reset_step =
                        make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());

                    move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
-               });
-           }
+               }
+           });

-           if constexpr(DstResetCoordinateAfterRun)
-           {
-               static_for<0, nDst, 1>{}([&](auto i) {
+           static_for<0, nDst, 1>{}([&](auto i) {
+               if constexpr(DstResetCoordinateAfterRunFlags::At(i))
+               {
                    const auto dst_reset_step =
                        make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());

                    move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
-               });
-           }
+               }
+           });
        }

    __device__ static constexpr auto GetCoordinateResetStep()

@@ -220,12 +233,13 @@ struct ThreadwiseTensorSliceTransfer_v7
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                       const Index& src_slice_origin_step_idx)
    {
-       // if src coord was not reset by RunRead(), then need to adjust the step here
-       const auto adjusted_step_idx = SrcResetCoordinateAfterRun
-                                          ? src_slice_origin_step_idx
-                                          : src_slice_origin_step_idx + GetCoordinateResetStep();
-
        static_for<0, nSrc, 1>{}([&](auto i) {
+           // if src coord was not reset by RunRead(), then need to adjust the step here
+           const auto adjusted_step_idx =
+               SrcResetCoordinateAfterRunFlags::At(i)
+                   ? src_slice_origin_step_idx
+                   : src_slice_origin_step_idx + GetCoordinateResetStep();
+
            // is it OK to construct a new step every time?
            const auto adjusted_step = make_tensor_coordinate_step(src_descs[i], adjusted_step_idx);

@@ -237,12 +251,13 @@ struct ThreadwiseTensorSliceTransfer_v7
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                       const Index& dst_slice_origin_step_idx)
    {
-       // if dst coord was not reset by Run(), then need to adjust the step here
-       const auto adjusted_step_idx = DstResetCoordinateAfterRun
-                                          ? dst_slice_origin_step_idx
-                                          : dst_slice_origin_step_idx + GetCoordinateResetStep();
-
        static_for<0, nDst, 1>{}([&](auto i) {
+           // if dst coord was not reset by Run(), then need to adjust the step here
+           const auto adjusted_step_idx =
+               DstResetCoordinateAfterRunFlags::At(i)
+                   ? dst_slice_origin_step_idx
+                   : dst_slice_origin_step_idx + GetCoordinateResetStep();
+
            // is it OK to construct a new step every time?
            const auto adjusted_step = make_tensor_coordinate_step(dst_descs[i], adjusted_step_idx);
@@ -6,6 +6,8 @@ namespace ck {
template <typename T>
union BufferResource
{
+   __device__ constexpr BufferResource() : content{} {}
+
    // 128 bit SGPRs to supply buffer resource in buffer instructions
    // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
    int32x4_t content;
@@ -17,14 +17,18 @@ struct TupleElementKey
};

template <typename Key, typename Data>
-struct TupleElement
+struct TupleElementKeyData
{
-   __host__ __device__ constexpr TupleElement() = default;
+#if 0
+   __host__ __device__ constexpr TupleElementKeyData() = default;
+#else
+   __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
+#endif

    template <
        typename T,
-       typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false>
-   __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
+       typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, bool>::type = false>
+   __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
    {
    }

@@ -32,20 +36,20 @@ struct TupleElement
};

template <typename Key, typename Data>
-__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
+__host__ __device__ constexpr const Data& get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
    return static_cast<const Data&>(x.mData);
}

template <typename Key, typename Data>
-__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
+__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
{
    return x.mData;
}

// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
-__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
+__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
{
    return static_cast<Data&&>(x.mData);
}

@@ -54,7 +58,7 @@ template <typename Indices, typename... Xs>
struct TupleImpl;

template <index_t... Is, typename... Xs>
-struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
+struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
{
    __host__ __device__ constexpr TupleImpl() = default;

@@ -63,13 +67,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
                      !is_same<remove_cvref_t<Y>, TupleImpl>::value,
                      bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Y&& y)
-       : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
+       : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
    {
    }

    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
-       : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
+       : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
    {
        static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                      "wrong! inconsistent size");

@@ -78,15 +82,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

    template <index_t I>
-   __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
+   __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
    {
-       return get_tuple_element<TupleElementKey<I>>(*this);
+       return get_tuple_element_data<TupleElementKey<I>>(*this);
    }

    template <index_t I>
-   __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
+   __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
    {
-       return get_tuple_element<TupleElementKey<I>>(*this);
+       return get_tuple_element_data<TupleElementKey<I>>(*this);
    }
};

@@ -121,7 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
    __host__ __device__ constexpr const auto& At(Number<I>) const
    {
        static_assert(I < base::Size(), "wrong! out of range");
-       return base::GetElementByKey(detail::TupleElementKey<I>{});
+       return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // write access
@@ -129,7 +133,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
    __host__ __device__ constexpr auto& At(Number<I>)
    {
        static_assert(I < base::Size(), "wrong! out of range");
-       return base::GetElementByKey(detail::TupleElementKey<I>{});
+       return base::GetElementDataByKey(detail::TupleElementKey<I>{});
    }

    // read access

@@ -159,6 +163,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};

+template <>
+struct Tuple<>
+{
+   __host__ __device__ constexpr Tuple() = default;
+
+   __host__ __device__ static constexpr index_t Size() { return 0; }
+
+   template <typename T>
+   __host__ __device__ constexpr auto operator=(const T&)
+   {
+       return *this;
+   }
+
+   __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+};
+
+template <index_t I, typename TTuple>
+struct tuple_element
+{
+   using type = decltype(TTuple{}.At(Number<I>{}));
+};
+
+template <index_t I, typename TTuple>
+using tuple_element_t = typename tuple_element<I, TTuple>::type;
+
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{