Commit ea7a8fca authored by Chao Liu's avatar Chao Liu
Browse files

contraction with multiple D

parent 6ef4e211
add_example_executable(example_contraction_xdl_fp32 contraction_xdl_fp32.cpp)
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
# Instructions for ```example_contraction_xdl_fp32```
# Instructions for ```example_contraction_bilinear_xdl_fp32```
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_xdl_fp32 1 1 1
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```
Result (MI100 @ dynammic freq, 46TFlops peak FP32)
......@@ -16,5 +16,5 @@ c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContraction_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
```
......@@ -12,27 +12,48 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceContraction : public BaseOperator
typename CDEElementwiseOperation>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::vector<index_t> c_lengths,
std::vector<index_t> c_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -11,11 +11,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DELayout,
......
......@@ -88,12 +88,15 @@ namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DELayout,
......@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
......@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
......@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
......@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
// ck::Tuple<const DsDataType*...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
......@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
......
......@@ -17,12 +17,15 @@
namespace ck {
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// GEMM:
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment