Commit ea7a8fca authored by Chao Liu

contraction with multiple D

parent 6ef4e211
add_example_executable(example_contraction_xdl_fp32 contraction_xdl_fp32.cpp)
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
# Instructions for ```example_contraction_xdl_fp32```
# Instructions for ```example_contraction_bilinear_xdl_fp32```
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_xdl_fp32 1 1 1
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```
Result (MI100 @ dynamic freq, 46 TFlops peak FP32)
@@ -16,5 +16,5 @@ c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContraction_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
```
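As a sanity check on these numbers: the contraction collapses to a GEMM with M = 30 × 128 = 3840, N = 32 × 64 = 2048, and K = 32 × 64 = 2048, so flop = 2 · M · N · K ≈ 3.22 × 10^10. Dividing by the measured 0.843286 ms gives ≈ 38.2 TFlops, roughly 83% of the quoted 46 TFlops peak.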
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -22,33 +22,33 @@ using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContraction_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
//############################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on
// hardcoded for NumDimM == NumDimN == NumDimK == 2
@@ -57,11 +57,11 @@ template <ck::index_t NumDimM,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename CDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
@@ -70,26 +70,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ks_ns_{b_ks_ns},
c_ms_ns_{c_ms_ns},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ks_ns_;
Tensor<CDataType>& c_ms_ns_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
@@ -123,16 +123,16 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.cde_element_op_(v_c, v_acc);
arg.c_ms_ns_(m0, m1, n0, n1) = v_c;
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.c_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])(
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
@@ -158,12 +158,12 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns,
Tensor<CDataType>& c_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ks_ns, c_ms_ns, a_element_op, b_element_op, c_element_op};
return Argument{a_ms_ks, b_ks_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -186,17 +186,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
};
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
@@ -217,15 +206,21 @@ int main(int argc, char* argv[])
exit(0);
}
const float alpha = 1;
const float beta = 1;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[K0, K1, N0, N1]
std::vector<ck::index_t> b_ks_ns_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ks_ns_strides{128, 1, 524288, 4096};
// C[M0, M1, N0, N1]
std::vector<ck::index_t> c_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> c_ms_ns_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
@@ -233,16 +228,20 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_ks_ns(
std::vector<std::size_t>(b_ks_ns_lengths.begin(), b_ks_ns_lengths.end()),
std::vector<std::size_t>(b_ks_ns_strides.begin(), b_ks_ns_strides.end()));
Tensor<CDataType> c_ms_ns_host_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()),
std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end()));
Tensor<CDataType> c_ms_ns_device_result(
std::vector<std::size_t>(c_ms_ns_lengths.begin(), c_ms_ns_lengths.end()),
std::vector<std::size_t>(c_ms_ns_strides.begin(), c_ms_ns_strides.end()));
Tensor<DDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl;
std::cout << "c_ms_ns: " << c_ms_ns_host_result.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -250,45 +249,54 @@ int main(int argc, char* argv[])
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ks_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
case 2:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ks_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
a_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b_ks_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_ms_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_ms_ks_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_ks_ns_device_buf(sizeof(BDataType) * b_ks_ns.mDesc.GetElementSpace());
DeviceMem c_ms_ns_device_buf(sizeof(CDataType) * c_ms_ns_device_result.mDesc.GetElementSpace());
DeviceMem d_ms_ns_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_ms_ns_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
a_ms_ks_device_buf.ToDevice(a_ms_ks.mData.data());
b_ks_ns_device_buf.ToDevice(b_ks_ns.mData.data());
d_ms_ns_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
c_ms_ns_device_buf.SetZero();
e_ms_ns_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
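Here `CDEElementOp{alpha, beta}` is the bilinear combination applied to the contraction result C and the extra tensor D. A minimal sketch of the per-element semantics (illustrative only, not ck's implementation), written output-parameter style to match how the verification loop below calls `cde_element_op(e, c, d)`:
```cpp
// Illustrative stand-in for ck::tensor_operation::element_wise::Bilinear:
// for every output element, e = alpha * c + beta * d.
struct BilinearSketch
{
    float alpha_;
    float beta_;

    void operator()(float& e, const float& c, const float& d) const
    {
        e = alpha_ * c + beta_ * d;
    }
};
```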
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(static_cast<ADataType*>(a_ms_ks_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_ks_ns_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_ms_ns_device_buf.GetDeviceBuffer()),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ks_ns_lengths,
b_ks_ns_strides,
c_ms_ns_lengths,
c_ms_ns_strides,
a_element_op,
b_element_op,
c_element_op);
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument =
op.MakeArgument(a_ms_ks_device_buf.GetDeviceBuffer(),
b_ks_ns_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_ms_ns_device_buf.GetDeviceBuffer()},
e_ms_ns_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ks_ns_lengths,
b_ks_ns_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
@@ -299,13 +307,13 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(c_ms_ns_lengths.begin(),
c_ms_ns_lengths.begin() + NumDimM,
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(c_ms_ns_lengths.begin() + NumDimM,
c_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
@@ -314,9 +322,9 @@ int main(int argc, char* argv[])
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
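// For the example shape: 4 B * (3840*2048 + 2048*2048 + 2*3840*2048) ~= 111.1 MB moved per invocation.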
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -325,19 +333,50 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
c_ms_ns_device_buf.FromDevice(c_ms_ns_device_result.mData.data());
e_ms_ns_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ks_ns, c_ms_ns_host_result, a_element_op, b_element_op, c_element_op);
a_ms_ks, b_ks_ns, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_ms_ns_device_result.mData, c_ms_ns_host_result.mData) ? 0 : 1;
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
}
return 0;
......
@@ -12,27 +12,48 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceContraction : public BaseOperator
typename CDEElementwiseOperation>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::vector<index_t> c_lengths,
std::vector<index_t> c_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
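To make the comment above concrete, here is a hedged host-side sketch (not part of ck) of the same computation with the M/N/K dimension groups pre-flattened and a single D tensor. `a_op` and `b_op` are assumed to return their transformed input, while `cde_op` writes through an output parameter as in the interface above:
```cpp
#include <cstddef>
#include <vector>

// Naive reference for E = cde_op(a_op(A) * b_op(B), D) with one D tensor;
// all tensors row-major, M/N/K already flattened to single extents.
template <typename AOp, typename BOp, typename CDEOp>
void reference_contraction_multiple_d(const std::vector<float>& a, // [M, K]
                                      const std::vector<float>& b, // [K, N]
                                      const std::vector<float>& d, // [M, N]
                                      std::vector<float>& e,       // [M, N]
                                      std::size_t M,
                                      std::size_t N,
                                      std::size_t K,
                                      AOp a_op,
                                      BOp b_op,
                                      CDEOp cde_op)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            // C = a_op(A) * b_op(B), accumulated over the contracted dimensions
            float c = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                c += a_op(a[m * K + k]) * b_op(b[k * N + n]);

            // E = cde_op(C, D0, ...), here with a single D
            cde_op(e[m * N + n], c, d[m * N + n]);
        }
    }
}
```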
......
@@ -10,31 +10,109 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_d_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// C[M0, M1, M2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -63,17 +141,24 @@ template <index_t NumDimM,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
NumDimN,
NumDimK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceContraction_Xdl_CShuffle;
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -335,19 +420,19 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
}
}
// assume C[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_lengths_vec,
const std::vector<index_t>& c_strides_vec)
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_lengths_vec,
const std::vector<index_t>& e_strides_vec)
{
assert(c_lengths_vec.size() == NumDimM + NumDimN &&
c_strides_vec.size() == NumDimM + NumDimN);
assert(e_lengths_vec.size() == NumDimM + NumDimN &&
e_strides_vec.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto Num) {
return generate_tuple([&](auto i) { return vec[i]; }, Num);
};
const auto c_lengths = to_tuple(c_lengths_vec, Number<NumDimM + NumDimN>{});
const auto c_strides = to_tuple(c_strides_vec, Number<NumDimM + NumDimN>{});
const auto e_lengths = to_tuple(e_lengths_vec, Number<NumDimM + NumDimN>{});
const auto e_strides = to_tuple(e_strides_vec, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
@@ -357,23 +442,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_lengths, mDimIds);
const auto mLengths = get_container_subset(e_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(c_lengths, nDimIds);
const auto nLengths = get_container_subset(e_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns = make_naive_tensor_descriptor(c_lengths, c_strides);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns = make_naive_tensor_descriptor(e_lengths, e_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
@@ -385,7 +470,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -396,7 +481,7 @@
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -406,7 +491,7 @@
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -414,7 +499,7 @@
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
return e_grid_desc_mraw_nraw;
}
}
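As a concrete check against the example log earlier in this commit: with E lengths {30, 128, 32, 64}, MRaw = 30 × 128 = 3840 and NRaw = 32 × 64 = 2048. The selected instance uses MPerBlock = 256 and NPerBlock = 128, so the Default specialization pads nothing and the block-to-E-tile map covers (3840 / 256) × (2048 / 128) = 15 × 16 = 240 workgroups, matching the reported grid_dim {240, 1, 1}.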
@@ -422,22 +507,23 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{}));
using BGridDesc_BK0_N_BK1 =
decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{}));
using CGridDesc_M_N =
decltype(MakeCGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
using EGridDesc_M_N =
decltype(MakeEGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -467,36 +553,41 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::vector<index_t> c_lengths,
std::vector<index_t> c_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_lengths, a_strides)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_lengths, b_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_lengths, c_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_lengths, e_strides)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_mz_length_{},
a_mz_stride_{},
a_kz_length_{},
@@ -505,19 +596,32 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
b_nz_stride_{},
b_kz_length_{},
b_kz_stride_{},
c_mz_length_{},
c_mz_stride_{},
c_nz_length_{},
c_nz_stride_{}
e_mz_length_{},
e_mz_stride_{},
e_nz_length_{},
e_nz_stride_{}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
e_grid_desc_m_n_,
block_2_etile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N(ds_lengths[i], ds_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
}
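The `static_for` above runs its lambda with a distinct compile-time index per iteration, which is what lets each D tensor carry its own pointer and descriptor type. A self-contained sketch of the idiom, using `std::integral_constant` as a stand-in for `ck::Number` (an assumption for illustration only):
```cpp
#include <cstdio>
#include <type_traits>
#include <utility>

// Minimal stand-in for ck::static_for: invokes f(std::integral_constant<int, I>{})
// for I = Begin, Begin + Inc, ... while I < End, so every iteration sees a
// distinct compile-time index usable for tuple element access.
template <int Begin, int End, int Inc, typename F>
void static_for_sketch(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        static_for_sketch<Begin + Inc, End, Inc>(std::forward<F>(f));
    }
}

int main()
{
    static_for_sketch<0, 3, 1>([](auto i) {
        // i.value is a compile-time constant inside the lambda
        std::printf("compile-time index %d\n", i.value);
    });
}
```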
// for sanity check of vector load/store
@@ -533,30 +637,43 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
b_kz_length_ = b_lengths[NumDimK + NumDimN - 1];
b_kz_stride_ = b_strides[NumDimK + NumDimN - 1];
c_mz_length_ = b_lengths[NumDimM - 1];
c_mz_stride_ = b_strides[NumDimM - 1];
e_mz_length_ = e_lengths[NumDimM - 1];
e_mz_stride_ = e_strides[NumDimM - 1];
c_nz_length_ = b_lengths[NumDimM + NumDimN - 1];
c_nz_stride_ = b_strides[NumDimM + NumDimN - 1];
e_nz_length_ = e_lengths[NumDimM + NumDimN - 1];
e_nz_stride_ = e_strides[NumDimM + NumDimN - 1];
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// These are last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
// FIXME: add check for D0, D1, ...
index_t a_mz_length_;
index_t a_mz_stride_;
index_t a_kz_length_;
@@ -565,10 +682,10 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
index_t b_nz_stride_;
index_t b_kz_length_;
index_t b_kz_stride_;
index_t c_mz_length_;
index_t c_mz_stride_;
index_t c_nz_length_;
index_t c_nz_stride_;
index_t e_mz_length_;
index_t e_mz_stride_;
index_t e_nz_length_;
index_t e_nz_stride_;
};
// Invoker
@@ -590,89 +707,73 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
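The `integral_constant` argument is what turns the runtime result of `CalculateHasMainKBlockLoop(K)` into the compile-time `HasMainKBlockLoop` template parameter inside `launch_kernel`. A self-contained sketch of the dispatch idiom (illustrative names only):
```cpp
#include <iostream>
#include <type_traits>

// Stand-in for a kernel that is specialized on a compile-time flag.
template <bool HasMainLoop>
void run_kernel_sketch()
{
    std::cout << "instantiated with HasMainLoop = " << HasMainLoop << '\n';
}

int main()
{
    // Imagine: the runtime check GridwiseGemm::CalculateHasMainKBlockLoop(K)
    bool has_main_k_block_loop = true;

    // Generic lambda: decltype(flag)::value is a compile-time constant, so
    // each branch instantiates a different specialization.
    auto launch = [](auto flag) { run_kernel_sketch<decltype(flag)::value>(); };

    if(has_main_k_block_loop)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
}
```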
@@ -701,8 +802,8 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
return false;
}
@@ -718,61 +819,74 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::vector<index_t> c_lengths,
std::vector<index_t> c_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_ds,
p_e,
a_lengths,
a_strides,
b_lengths,
b_strides,
c_lengths,
c_strides,
ds_lengths,
ds_strides,
e_lengths,
e_strides,
a_element_op,
b_element_op,
c_element_op};
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::vector<index_t> c_lengths,
std::vector<index_t> c_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_lengths,
a_strides,
b_lengths,
b_strides,
c_lengths,
c_strides,
ds_lengths,
ds_strides,
e_lengths,
e_strides,
a_element_op,
b_element_op,
c_element_op);
cde_element_op);
}
// polymorphic
@@ -787,7 +901,7 @@ struct DeviceContraction_Xdl_CShuffle : public DeviceContraction<NumDimM,
auto str = std::stringstream();
// clang-format off
str << "DeviceContraction_Xdl_CShuffle"
str << "DeviceContractionMultipleD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -11,11 +11,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DELayout,
......
@@ -88,12 +88,15 @@ namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[AK0, M, AK1]
// input : B[BK0, N, BK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DELayout,
@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
// ck::Tuple<const DsDataType*...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
......
@@ -17,12 +17,15 @@
namespace ck {
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[AK0, M, AK1]
// input : B[BK0, N, BK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......