Commit c7d59414 authored by Chao Liu

debugging

parent ea3feee5
@@ -27,35 +27,60 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
struct AddAddFastGelu
{
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ void
operator()(ck::half_t& y, const float& x0, const ck::half_t& x1, const ck::half_t& x2) const
{
#if 0
const float x = x0 + x1 + x2;
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
#else
const float x = x0 + x2;
y = x;
#endif
}
};
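// Illustrative sketch, not part of this commit: the #else branch above bypasses x1 and the
// activation for debugging. With the debug path removed, the fused epilogue would presumably
// apply FastGelu to the sum of all three inputs, e.g.:
//
//     const float x   = x0 + x1 + x2;
//     const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
//     const float emu = exp(-u);
//     const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
//     y = x * cdf;
//
// Here u = 2*sqrt(2/pi)*(x + 0.044715*x^3) and 2/(1+e^-u) - 1 = tanh(u/2), so cdf matches the
// tanh-based GELU formula quoted in the comment above.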
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
using AccDataType = F32;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
#if 0
using CDEElementOp = FastGelu;
#else
using CDEElementOp = AddAddFastGelu;
#endif
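// With the #else branch selected, the example instantiates the 3-input epilogue (AddAddFastGelu),
// which is why DsDataType now carries two extra tensors (D0DataType, D1DataType): the device GEMM
// computes E = CDEElementOp(C, D0, D1).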
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F32, F32, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, EDataType, AcEDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
@@ -69,7 +94,7 @@ int main(int argc, char* argv[])
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideE = 4096;
if(argc == 1)
{
@@ -93,14 +118,14 @@ int main(int argc, char* argv[])
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0); exit(0);
} }
@@ -120,57 +145,87 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> d0_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
Tensor<EDataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
#if 0
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
#else
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
#endif
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
#if 0
d0_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
#else
d0_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
#endif
}
std::cout << "a: " << a_m_k.mDesc.GetElementSpace() << std::endl;
std::cout << "b: " << b_k_n.mDesc.GetElementSpace() << std::endl;
std::cout << "d0: " << d0_m_n.mDesc.GetElementSpace() << std::endl;
std::cout << "d1: " << d1_m_n.mDesc.GetElementSpace() << std::endl;
std::cout << "e: " << e_m_n_device_result.mDesc.GetElementSpace() << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
#if 1
DeviceMem d0_m_n_device_buf(sizeof(EDataType) * d0_m_n.mDesc.GetElementSpace());
#else
DeviceMem d0_m_n_device_buf(sizeof(EDataType) * d1_m_n.mDesc.GetElementSpace());
#endif
DeviceMem d1_m_n_device_buf(sizeof(EDataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{{d0_m_n_device_buf.GetDeviceBuffer(),
d1_m_n_device_buf.GetDeviceBuffer()}},
e_m_n_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{{0, StrideE}},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
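// Note on the two new argument packs: the D tensors are passed as raw pointers plus per-tensor
// strides. D0 gets stride 0, so the same N-wide row is broadcast over every row of the MxN
// output (matching d0_m_n above, which was also built with stride 0), while D1 is a full MxN
// tensor with stride StrideE. For each element the device epilogue then evaluates
//     e(m, n) = cde_element_op(acc(m, n), d0(m, n), d1(m, n))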
if(!gemm.IsSupportedArgument(argument))
{
@@ -192,19 +247,54 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
#if 1
Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
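// The host check mirrors the device path: a plain reference GEMM accumulates C in AccDataType
// (with PassThrough as the C element-op), and only then is the fused CDE epilogue applied per
// element against D0 and D1, as in the loop above.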
#else
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, CDEElementOp{});
ref_invoker.Run(ref_argument);
#endif
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
}
return 0;
@@ -4,6 +4,7 @@
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer_v7.hpp"
namespace ck {
@@ -111,6 +112,20 @@ struct ThreadGroupTensorSliceTransfer_v6r1
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
#if 0
ThreadwiseTensorSliceTransfer_v7<Tuple<SrcData>,
Tuple<DstData>,
Tuple<SrcDesc>,
Tuple<DstDesc>,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
DstInMemOp>;
#else
ThreadwiseTensorSliceTransfer_v6r1<SrcData, ThreadwiseTensorSliceTransfer_v6r1<SrcData,
DstData, DstData,
SrcDesc, SrcDesc,
...@@ -123,8 +138,9 @@ struct ThreadGroupTensorSliceTransfer_v6r1 ...@@ -123,8 +138,9 @@ struct ThreadGroupTensorSliceTransfer_v6r1
DstInMemOp, DstInMemOp,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>; ThreadTransferDstResetCoordinateAfterRun>;
#endif
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
@@ -98,7 +98,6 @@ template <typename ALayout,
typename BDataType,
typename GemmAccDataType,
typename CShuffleDataType,
index_t NumDTensor,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
@@ -135,13 +134,15 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType::Size(),
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -419,6 +420,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<NumDTensor,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -501,7 +503,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<NumDTensor,
e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
@@ -527,16 +529,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<NumDTensor,
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -556,6 +556,38 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<NumDTensor,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 1
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I0)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I1)
<< "}" << std::endl;
std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I0)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I1)
<< "}" << std::endl;
std::cout << "p_ds_grid{ " << arg.p_ds_grid_[I0] << ", " << arg.p_ds_grid_[I1]
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_,
@@ -576,7 +608,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<NumDTensor,
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -8,6 +8,7 @@
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_v1.hpp"
...@@ -16,6 +17,7 @@ namespace ck { ...@@ -16,6 +17,7 @@ namespace ck {
template <typename FloatAB, template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename DsDataType,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -58,6 +60,8 @@ template <typename FloatAB,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -109,6 +113,18 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
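// For example, with DsDataType = ck::Tuple<D0DataType, D1DataType> (both F16 in the host example
// above), DsGridPointer works out to ck::Tuple<const half_t*, const half_t*>; the typed null
// pointers exist only so decltype() can name that tuple type.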
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
@@ -215,26 +231,41 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const StaticallyIndexedArray<CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>&
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: use tuple
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -505,6 +536,42 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
#if 1
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
#else
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
@@ -515,20 +582,21 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
#endif
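// In the #if 1 branch above, the v6r3 thread-group transfer fuses the epilogue into the
// LDS-to-global copy: src0 is the CShuffle tile in LDS, src1/src2 are the two D tensors read
// directly from global memory, and the destination is E, with c_element_op applied per element.
// The #else branch is the original single-source v6r1 copy, kept for reference.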
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
@@ -571,15 +639,35 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
block_sync_lds();
// each block copy its data from LDS to global
#if 1
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
#else
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
#endif
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
#if 1
// move on Ds
c_shuffle_block_copy_lds_to_global.MoveSrc1SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step);
c_shuffle_block_copy_lds_to_global.MoveSrc2SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);
#endif
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool SrcResetCoordinateAfterRun,
bool DstResetCoordinateAfterRun,
InMemoryDataOperationEnum... DstInMemOps>
struct ThreadwiseTensorSliceTransfer_v7
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
{
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
Number<Descs::Size()>{});
}
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve =
SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(scalar_per_access)>>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
}
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetSrcSliceOrigin(const SrcDescs& src_descs,
const Indices& src_slice_origin_idxs)
{
static_for<0, nSrc, 1>{}([&](auto i) {
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
});
}
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetDstSliceOrigin(const DstDescs& dst_descs,
const Indices& dst_slice_origin_idxs)
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
});
}
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers& dst_bufs)
{
// make one ScalarPerVector-wide vector register per tensor in the given type tuple
auto generate_vectors = [&](auto data_types) {
constexpr index_t num = decltype(data_types)::Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
};
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors(SrcDatas{});
auto dst_vectors = generate_vectors(DstDatas{});
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
src_vectors(i) = src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
});
// apply pointwise function
// FIXME: support tuple of arbitary size
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
using SrcData0 = remove_cvref_t<decltype(SrcDatas{}.At(I0))>;
using DstData0 = remove_cvref_t<decltype(DstDatas{}.At(I0))>;
element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
src_vectors[I0].template AsType<SrcData0>()[i]);
});
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]);
constexpr auto DstInMemOp = make_tuple(DstInMemOps...)[i];
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(),
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
static_for<0, nDst, 1>{}([&](auto i) {
move_tensor_coordinate(dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
static_for<0, nSrc, 1>{}([&](auto i) {
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
});
}
if constexpr(DstResetCoordinateAfterRun)
{
static_for<0, nDst, 1>{}([&](auto i) {
const auto dst_reset_step =
make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
});
}
}
__device__ static constexpr auto GetCoordinateResetStep()
{
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
static_for<0, nSrc, 1>{}([&](auto i) {
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_descs[i], adjusted_step_idx);
move_tensor_coordinate(src_descs[i], src_coords_(i), adjusted_step);
});
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
static_for<0, nDst, 1>{}([&](auto i) {
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[i], adjusted_step_idx);
move_tensor_coordinate(dst_descs[i], dst_coords_(i), adjusted_step);
});
}
private:
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
};
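// Illustrative usage sketch (assumes the caller has already built matching tuples of descriptors,
// slice origins and dynamic buffers; not taken from this commit):
//
//   auto transfer = ThreadwiseTensorSliceTransfer_v7<...>{
//       src_descs, src_origins, dst_descs, dst_origins, element_op};
//
//   transfer.Run(src_descs, src_bufs, dst_descs, dst_bufs);  // copy + elementwise op, one slice
//   transfer.MoveSrcSliceWindow(src_descs, step);            // advance to the next slice
//   transfer.MoveDstSliceWindow(dst_descs, step);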
} // namespace ck
#pragma once
#define CK_ENABLE_IF_HPP
namespace ck {
@@ -10,4 +9,3 @@ template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif