Commit ea7a8fca authored by Chao Liu's avatar Chao Liu
Browse files

contraction with multiple D

parent 6ef4e211
add_example_executable(example_contraction_xdl_fp32 contraction_xdl_fp32.cpp) add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
# Instructions for ```example_contraction_xdl_fp32``` # Instructions for ```example_contraction_bilinear_xdl_fp32```
## Run ## Run
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes) #arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_xdl_fp32 1 1 1 ./bin/example_contraction_bilinear_xdl_fp32 1 1 1
``` ```
Result (MI100 @ dynammic freq, 46TFlops peak FP32) Result (MI100 @ dynammic freq, 46TFlops peak FP32)
...@@ -16,5 +16,5 @@ c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} ...@@ -16,5 +16,5 @@ c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time Warm up 1 time
Start running 10 times... Start running 10 times...
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContraction_Xdl_CShuffle<256, 256, 128, 16, 4, 4> Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
``` ```
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -22,33 +22,33 @@ using S = ck::Sequence<Is...>; ...@@ -22,33 +22,33 @@ using S = ck::Sequence<Is...>;
using F32 = float; using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = float; using ADataType = F32;
using BDataType = float; using BDataType = F32;
using CDataType = float; using AccDataType = F32;
using AccDataType = float; using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2; static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2; static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device:: using DeviceOpInstance = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //############################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on // clang-format on
// hardcoded for NumDimM == NumDimN == NumDimK == 2 // hardcoded for NumDimM == NumDimN == NumDimK == 2
...@@ -57,11 +57,11 @@ template <ck::index_t NumDimM, ...@@ -57,11 +57,11 @@ template <ck::index_t NumDimM,
ck::index_t NumDimK, ck::index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename AccDataType, typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false> ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{ {
...@@ -70,26 +70,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -70,26 +70,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{ {
Argument(const Tensor<ADataType>& a_ms_ks, Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns, const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns, Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks}, : a_ms_ks_{a_ms_ks},
b_ks_ns_{b_ks_ns}, b_ks_ns_{b_ks_ns},
c_ms_ns_{c_ms_ns}, e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} cde_element_op_{cde_element_op}
{ {
} }
const Tensor<ADataType>& a_ms_ks_; const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ks_ns_; const Tensor<BDataType>& b_ks_ns_;
Tensor<CDataType>& c_ms_ns_; Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CDEElementwiseOperation cde_element_op_;
}; };
// Invoker // Invoker
...@@ -123,16 +123,16 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -123,16 +123,16 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.cde_element_op_(v_c, v_acc);
arg.c_ms_ns_(m0, m1, n0, n1) = v_c; arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.c_ms_ns_.mDesc.GetLengths()[0], arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1], arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2], arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])( arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -158,12 +158,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -158,12 +158,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks, static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns, const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns, Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{a_ms_ks, b_ks_ns, c_ms_ns, a_element_op, b_element_op, c_element_op}; return Argument{a_ms_ks, b_ks_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -186,17 +186,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -186,17 +186,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
} }
}; };
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -217,15 +206,21 @@ int main(int argc, char* argv[]) ...@@ -217,15 +206,21 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
const float alpha = 1;
const float beta = 1;
// A[M0, M1, K0, K1] // A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64}; std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1}; std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[K0, K1, N0, N1] // B[K0, K1, N0, N1]
std::vector<ck::index_t> b_ks_ns_lengths{32, 64, 32, 64}; std::vector<ck::index_t> b_ks_ns_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ks_ns_strides{128, 1, 524288, 4096}; std::vector<ck::index_t> b_ks_ns_strides{128, 1, 524288, 4096};
// C[M0, M1, N0, N1] // D[M0, M1, N0, N1]
std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64}; std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1}; std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
...@@ -233,16 +228,20 @@ int main(int argc, char* argv[]) ...@@ -233,16 +228,20 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_ks_ns( Tensor<BDataType> b_ks_ns(
std::vector<std::size_t>(b_ks_ns_lengths.begin(), b_ks_ns_lengths.end()), std::vector<std::size_t>(b_ks_ns_lengths.begin(), b_ks_ns_lengths.end()),
std::vector<std::size_t>(b_ks_ns_strides.begin(), b_ks_ns_strides.end())); std::vector<std::size_t>(b_ks_ns_strides.begin(), b_ks_ns_strides.end()));
Tensor<CDataType> c_ms_ns_host_result( Tensor<EDataType> d_ms_ns(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()), std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end())); std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<CDataType> c_ms_ns_device_result( Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()), std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end())); std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl; std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl;
std::cout << "c_ms_ns: " << c_ms_ns_host_result.mDesc << std::endl; std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -250,45 +249,54 @@ int main(int argc, char* argv[]) ...@@ -250,45 +249,54 @@ int main(int argc, char* argv[])
case 1: case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ks_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_ks_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break; break;
case 2: case 2:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ks_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_ks_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b_ks_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_ks_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_ms_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
} }
DeviceMem a_ms_ks_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); DeviceMem a_ms_ks_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_ks_ns_device_buf(sizeof(BDataType) * b_ks_ns.mDesc.GetElementSpace()); DeviceMem b_ks_ns_device_buf(sizeof(BDataType) * b_ks_ns.mDesc.GetElementSpace());
DeviceMem c_ms_ns_device_buf(sizeof(CDataType) * c_ms_ns_device_result.mDesc.GetElementSpace()); DeviceMem d_ms_ns_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_ms_ns_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
a_ms_ks_device_buf.ToDevice(a_ms_ks.mData.data()); a_ms_ks_device_buf.ToDevice(a_ms_ks.mData.data());
b_ks_ns_device_buf.ToDevice(b_ks_ns.mData.data()); b_ks_ns_device_buf.ToDevice(b_ks_ns.mData.data());
d_ms_ns_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero // set zero
c_ms_ns_device_buf.SetZero(); e_ms_ns_device_buf.SetZero();
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{alpha, beta};
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(static_cast<ADataType*>(a_ms_ks_device_buf.GetDeviceBuffer()), auto argument =
static_cast<BDataType*>(b_ks_ns_device_buf.GetDeviceBuffer()), op.MakeArgument(a_ms_ks_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_ms_ns_device_buf.GetDeviceBuffer()), b_ks_ns_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths, std::array<const void*, 1>{d_ms_ns_device_buf.GetDeviceBuffer()},
a_ms_ks_strides, e_ms_ns_device_buf.GetDeviceBuffer(),
b_ks_ns_lengths, a_ms_ks_lengths,
b_ks_ns_strides, a_ms_ks_strides,
c_ms_ns_lengths, b_ks_ns_lengths,
c_ms_ns_strides, b_ks_ns_strides,
a_element_op, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
b_element_op, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
c_element_op); e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument)) if(!op.IsSupportedArgument(argument))
{ {
...@@ -299,13 +307,13 @@ int main(int argc, char* argv[]) ...@@ -299,13 +307,13 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(c_ms_ns_lengths.begin(), ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
c_ms_ns_lengths.begin() + NumDimM, e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(c_ms_ns_lengths.begin() + NumDimM, ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
c_ms_ns_lengths.begin() + NumDimM + NumDimN, e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
...@@ -314,9 +322,9 @@ int main(int argc, char* argv[]) ...@@ -314,9 +322,9 @@ int main(int argc, char* argv[])
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -325,19 +333,50 @@ int main(int argc, char* argv[]) ...@@ -325,19 +333,50 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
c_ms_ns_device_buf.FromDevice(c_ms_ns_device_result.mData.data()); e_ms_ns_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ks_ns, c_ms_ns_host_result, a_element_op, b_element_op, c_element_op); a_ms_ks, b_ks_ns, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_ms_ns_device_result.mData, c_ms_ns_host_result.mData) ? 0 : 1; for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -12,27 +12,48 @@ namespace ck { ...@@ -12,27 +12,48 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceContraction : public BaseOperator struct DeviceContractionMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths, std::vector<index_t> a_lengths,
std::vector<index_t> a_strides, std::vector<index_t> a_strides,
std::vector<index_t> b_lengths, std::vector<index_t> b_lengths,
std::vector<index_t> b_strides, std::vector<index_t> b_strides,
std::vector<index_t> c_lengths, std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::vector<index_t> c_strides, std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -10,31 +10,109 @@ ...@@ -10,31 +10,109 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_d_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume: // Assume:
// A[M0, M1, M2, ..., K0, K1, K2...] // A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...] // B[K0, K1, K2, ..., N0, N1, N2...]
// C[M0, M1, M2, ..., N0, N1, N2...] // D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CDEElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -63,17 +141,24 @@ template <index_t NumDimM, ...@@ -63,17 +141,24 @@ template <index_t NumDimM,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, struct DeviceContractionMultipleD_Xdl_CShuffle
NumDimN, : public DeviceContractionMultipleD<NumDimM,
NumDimK, NumDimN,
AElementwiseOperation, NumDimK,
BElementwiseOperation, ADataType,
CElementwiseOperation> BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{ {
using DeviceOp = DeviceContraction_Xdl_CShuffle; using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -335,19 +420,19 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -335,19 +420,19 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
} }
} }
// assume C[M0, M1, M2, ..., N0, N1, N2...] // assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_lengths_vec, static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_lengths_vec,
const std::vector<index_t>& c_strides_vec) const std::vector<index_t>& e_strides_vec)
{ {
assert(c_lengths_vec.size() == NumDimM + NumDimN && assert(e_lengths_vec.size() == NumDimM + NumDimN &&
c_strides_vec.size() == NumDimM + NumDimN); e_strides_vec.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto Num) { const auto to_tuple = [&](auto& vec, auto Num) {
return generate_tuple([&](auto i) { return vec[i]; }, Num); return generate_tuple([&](auto i) { return vec[i]; }, Num);
}; };
const auto c_lengths = to_tuple(c_lengths_vec, Number<NumDimM + NumDimN>{}); const auto e_lengths = to_tuple(e_lengths_vec, Number<NumDimM + NumDimN>{});
const auto c_strides = to_tuple(c_strides_vec, Number<NumDimM + NumDimN>{}); const auto e_strides = to_tuple(e_strides_vec, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ... // dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
...@@ -357,23 +442,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -357,23 +442,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{}; typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ... // lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_lengths, mDimIds); const auto mLengths = get_container_subset(e_lengths, mDimIds);
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_lengths, nDimIds); const auto nLengths = get_container_subset(e_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...] // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns = make_naive_tensor_descriptor(c_lengths, c_strides); const auto e_grid_desc_ms_ns = make_naive_tensor_descriptor(e_lengths, e_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor( const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns, e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds), make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0); const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1); const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
...@@ -385,7 +470,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -385,7 +470,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -396,7 +481,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -396,7 +481,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
{ {
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -406,7 +491,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -406,7 +491,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
{ {
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -414,7 +499,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -414,7 +499,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
else else
{ {
// not pad M or N // not pad M or N
return c_grid_desc_mraw_nraw; return e_grid_desc_mraw_nraw;
} }
} }
...@@ -422,22 +507,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -422,22 +507,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{})); decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{}));
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 =
decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{})); decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{}));
using CGridDesc_M_N = using EGridDesc_M_N =
decltype(MakeCGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{})); decltype(MakeEGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, DsDataType,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -467,36 +553,41 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -467,36 +553,41 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
BBlockLdsExtraN, BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(const void* p_a_grid,
const BDataType* p_b_grid, const void* p_b_grid,
CDataType* p_c_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
std::vector<index_t> a_lengths, std::vector<index_t> a_lengths,
std::vector<index_t> a_strides, std::vector<index_t> a_strides,
std::vector<index_t> b_lengths, std::vector<index_t> b_lengths,
std::vector<index_t> b_strides, std::vector<index_t> b_strides,
std::vector<index_t> c_lengths, std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::vector<index_t> c_strides, std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CDEElementwiseOperation cde_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{p_b_grid}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_c_grid_{p_c_grid}, p_ds_grid_{}, // FIXME
a_element_op_{a_element_op}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_lengths, a_strides)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_lengths, a_strides)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_lengths, b_strides)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_lengths, b_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_lengths, c_strides)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_lengths, e_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_mz_length_{}, a_mz_length_{},
a_mz_stride_{}, a_mz_stride_{},
a_kz_length_{}, a_kz_length_{},
...@@ -505,19 +596,32 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -505,19 +596,32 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
b_nz_stride_{}, b_nz_stride_{},
b_kz_length_{}, b_kz_length_{},
b_kz_stride_{}, b_kz_stride_{},
c_mz_length_{}, e_mz_length_{},
c_mz_stride_{}, e_mz_stride_{},
c_nz_length_{}, e_nz_length_{},
c_nz_stride_{} e_nz_stride_{}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_, e_grid_desc_m_n_,
block_2_ctile_map_)) block_2_etile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N(ds_lengths[i], ds_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
} }
// for sanity check of vector load/store // for sanity check of vector load/store
...@@ -533,30 +637,43 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -533,30 +637,43 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
b_kz_length_ = b_lengths[NumDimK + NumDimN - 1]; b_kz_length_ = b_lengths[NumDimK + NumDimN - 1];
b_kz_stride_ = b_strides[NumDimK + NumDimN - 1]; b_kz_stride_ = b_strides[NumDimK + NumDimN - 1];
c_mz_length_ = b_lengths[NumDimM - 1]; e_mz_length_ = b_lengths[NumDimM - 1];
c_mz_stride_ = b_strides[NumDimM - 1]; e_mz_stride_ = b_strides[NumDimM - 1];
c_nz_length_ = b_lengths[NumDimM + NumDimN - 1]; e_nz_length_ = b_lengths[NumDimM + NumDimN - 1];
c_nz_stride_ = b_strides[NumDimM + NumDimN - 1]; e_nz_stride_ = b_strides[NumDimM + NumDimN - 1];
} }
// private: // private:
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
c_grid_desc_mblock_mperblock_nblock_nperblock_; NumDTensor>
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// These are last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store // for sanity check of vector load/store
// FIXME: add check for D0, D1, ...
index_t a_mz_length_; index_t a_mz_length_;
index_t a_mz_stride_; index_t a_mz_stride_;
index_t a_kz_length_; index_t a_kz_length_;
...@@ -565,10 +682,10 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -565,10 +682,10 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
index_t b_nz_stride_; index_t b_nz_stride_;
index_t b_kz_length_; index_t b_kz_length_;
index_t b_kz_stride_; index_t b_kz_stride_;
index_t c_mz_length_; index_t e_mz_length_;
index_t c_mz_stride_; index_t e_mz_stride_;
index_t c_nz_length_; index_t e_nz_length_;
index_t c_nz_stride_; index_t e_nz_stride_;
}; };
// Invoker // Invoker
...@@ -590,89 +707,73 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -590,89 +707,73 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0; auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, ck::StaticallyIndexedArray<
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
true>; NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ave_time = typename GridwiseGemm::DefaultBlock2ETileMap,
launch_and_time_kernel(stream_config, has_main_loop>;
kernel,
dim3(grid_size), return launch_and_time_kernel(stream_config,
dim3(BlockSize), kernel,
0, dim3(grid_size),
arg.p_a_grid_, dim3(BlockSize),
arg.p_b_grid_, 0,
arg.p_c_grid_, arg.p_a_grid_,
arg.a_element_op_, arg.p_b_grid_,
arg.b_element_op_, arg.p_ds_grid_,
arg.c_element_op_, arg.p_e_grid_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_element_op_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_element_op_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.cde_element_op_,
arg.block_2_ctile_map_); arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< ave_time = launch_kernel(integral_constant<bool, false>{});
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -701,8 +802,8 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -701,8 +802,8 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_etile_map_))
{ {
return false; return false;
} }
...@@ -718,61 +819,74 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -718,61 +819,74 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const void* p_a,
const BDataType* p_b, const void* p_b,
CDataType* p_c, std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths, std::vector<index_t> a_lengths,
std::vector<index_t> a_strides, std::vector<index_t> a_strides,
std::vector<index_t> b_lengths, std::vector<index_t> b_lengths,
std::vector<index_t> b_strides, std::vector<index_t> b_strides,
std::vector<index_t> c_lengths, std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::vector<index_t> c_strides, std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_ds,
p_e,
a_lengths, a_lengths,
a_strides, a_strides,
b_lengths, b_lengths,
b_strides, b_strides,
c_lengths, ds_lengths,
c_strides, ds_strides,
e_lengths,
e_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(const void* p_a,
void* p_c, const void* p_b,
std::vector<index_t> a_lengths, std::array<const void*, NumDTensor> p_ds,
std::vector<index_t> a_strides, void* p_e,
std::vector<index_t> b_lengths, std::vector<index_t> a_lengths,
std::vector<index_t> b_strides, std::vector<index_t> a_strides,
std::vector<index_t> c_lengths, std::vector<index_t> b_lengths,
std::vector<index_t> c_strides, std::vector<index_t> b_strides,
AElementwiseOperation a_element_op, std::array<std::vector<index_t>, NumDTensor> ds_lengths,
BElementwiseOperation b_element_op, std::array<std::vector<index_t>, NumDTensor> ds_strides,
CElementwiseOperation c_element_op) override std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(p_a,
static_cast<const BDataType*>(p_b), p_b,
static_cast<CDataType*>(p_c), p_ds,
p_e,
a_lengths, a_lengths,
a_strides, a_strides,
b_lengths, b_lengths,
b_strides, b_strides,
c_lengths, ds_lengths,
c_strides, ds_strides,
e_lengths,
e_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); cde_element_op);
} }
// polymorphic // polymorphic
...@@ -787,7 +901,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM, ...@@ -787,7 +901,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceContraction_Xdl_CShuffle" str << "DeviceContractionMultipleD_Xdl_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -11,11 +11,14 @@ namespace ck { ...@@ -11,11 +11,14 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// input : A[M, K], B[K, N], // GEMM:
// input : D0[M, N], D1[M, N], ... // input : A[M, K], B[K, N],
// output : E[M, N] // input : D0[M, N], D1[M, N], ...
// C = a_op(A) * b_op(B) // output : E[M, N]
// E = cde_op(C, D0, D1, ...) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DELayout,
......
...@@ -88,12 +88,15 @@ namespace ck { ...@@ -88,12 +88,15 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// input : A[M, K], or A[K, N] // GEMM:
// input : B[K, N], or A[N, K] // input : A[AK0, M, AK1]
// input : D0[M, N], D1[M, N], ... // input : B[AK0, N, AK1]
// output : E[M, N] // input : D0[M, N], D1[M, N], ...
// C = a_op(A) * b_op(B) // output : E[M, N]
// E = cde_op(C, D0, D1, ...) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DELayout,
...@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
} }
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
...@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
...@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n = const auto d_grid_desc_m_n =
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
} }
// ck::Tuple<const DsDataType*...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
// private: // private:
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< StaticallyIndexedArray<
...@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
......
...@@ -17,12 +17,15 @@ ...@@ -17,12 +17,15 @@
namespace ck { namespace ck {
// input : A[AK0, M, AK1] // GEMM:
// input : B[AK0, N, AK1] // input : A[AK0, M, AK1]
// input : D0[M, N], D1[M, N], ... // input : B[AK0, N, AK1]
// output : E[M, N] // input : D0[M, N], D1[M, N], ...
// C = a_op(A) * b_op(B) // output : E[M, N]
// E = cde_op(C, D0, D1, ...) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename FloatAB, template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment