Commit 460b059e authored by Chao Liu's avatar Chao Liu
Browse files

add KNN, KKN, MNN, MKN layout

parent 7fd0e649
...@@ -42,14 +42,40 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -42,14 +42,40 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device:: // Fast changing dimension in A/B/C are K/N/N dimensions
using ContractionInstanceKNN = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// Fast changing dimension in A/B/C are K/K/N dimensions
using ContractionInstanceKKN = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
// Fast changing dimension in A/B/C are M/N/N dimensions
using ContractionInstanceMNN = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// Fast changing dimension in A/B/C are M/K/N dimensions
using ContractionInstanceMKN = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on // clang-format on
using ContractionInstance = ContractionInstanceMKN;
template <typename T, typename Range> template <typename T, typename Range>
void LogRangeToFile(std::ofstream& fs, Range&& range, std::string delim) void LogRangeToFile(std::ofstream& fs, Range&& range, std::string delim)
...@@ -242,6 +268,7 @@ int main(int argc, char* argv[]) ...@@ -242,6 +268,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
// fast changing dimension: K/K/N
// a[m0, m1, k0, k1] // a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64}; std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
//std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1}; //std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
...@@ -251,19 +278,52 @@ int main(int argc, char* argv[]) ...@@ -251,19 +278,52 @@ int main(int argc, char* argv[])
// c[m0, m1, n0, n1] // c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64}; std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64};
//std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1}; //std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1};
#else #elif 0
// fast changing dimension: K/N/N
// a[m0, m1, k0, k1] // a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{5,6,3,4}; std::vector<ck::index_t> a_ms_ks_lengths{5,6,3,4};
//std::vector<ck::index_t> a_ms_ks_strides{108,20,16,1}; std::vector<ck::index_t> a_ms_ks_strides{108,20,16,1};
// b[k0, k1, n0, n1] // b[k0, k1, n0, n1]
std::vector<ck::index_t> b_ks_ns_lengths{3,4,3,4}; std::vector<ck::index_t> b_ks_ns_lengths{3,4,3,4};
//std::vector<ck::index_t> b_ks_ns_strides{16,1,108,20}; std::vector<ck::index_t> b_ks_ns_strides{48,12,4,1};
// c[m0, m1, n0, n1] // c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{5,6,3,4}; std::vector<ck::index_t> c_ms_ns_lengths{5,6,3,4};
//std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1}; std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1};
#elif 0
// fast changing dimension: K/K/N
// a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{5,6,3,4};
std::vector<ck::index_t> a_ms_ks_strides{108,20,16,1};
// b[k0, k1, n0, n1]
std::vector<ck::index_t> b_ks_ns_lengths{3,4,3,4};
std::vector<ck::index_t> b_ks_ns_strides{16,1,108,20};
// c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{5,6,3,4};
std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1};
#elif 0
// fast changing dimension: M/N/N
// a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{5,6,3,4};
std::vector<ck::index_t> a_ms_ks_strides{6,1,72,24};
// b[k0, k1, n0, n1]
std::vector<ck::index_t> b_ks_ns_lengths{3,4,3,4};
std::vector<ck::index_t> b_ks_ns_strides{48,12,4,1};
// c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{5,6,3,4};
std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1};
#elif 1
// fast changing dimension: M/K/N
// a[m0, m1, k0, k1]
std::vector<ck::index_t> a_ms_ks_lengths{5,6,3,4};
std::vector<ck::index_t> a_ms_ks_strides{6,1,72,24};
// b[k0, k1, n0, n1]
std::vector<ck::index_t> b_ks_ns_lengths{3,4,3,4};
std::vector<ck::index_t> b_ks_ns_strides{16,1,108,20};
// c[m0, m1, n0, n1]
std::vector<ck::index_t> c_ms_ns_lengths{5,6,3,4};
std::vector<ck::index_t> c_ms_ns_strides{108,20,16,1};
#endif #endif
#if 0
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
...@@ -276,16 +336,6 @@ int main(int argc, char* argv[]) ...@@ -276,16 +336,6 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_ms_ns_device_result( Tensor<CDataType> c_ms_ns_device_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()), std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()),
std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end())); std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end()));
#else
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()));
Tensor<BDataType> b_ks_ns(
std::vector<std::size_t>(b_ks_ns_lengths.begin(), b_ks_ns_lengths.end()));
Tensor<CDataType> c_ms_ns_host_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()));
Tensor<CDataType> c_ms_ns_device_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()));
#endif
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl; std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl;
...@@ -330,7 +380,7 @@ int main(int argc, char* argv[]) ...@@ -330,7 +380,7 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = ContractionInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(static_cast<ADataType*>(a_ms_ks_device_buf.GetDeviceBuffer()), auto argument = op.MakeArgument(static_cast<ADataType*>(a_ms_ks_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_ks_ns_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_ks_ns_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment