Commit dfbb659a authored by Chao Liu
Browse files

update contraction example

parent 809799bf
...@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator ...@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
std::vector<index_t> a_ms_ks_lengths, const std::vector<index_t>& a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths, const std::vector<index_t>& e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides, const std::vector<index_t>& e_ms_ns_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -106,7 +107,7 @@ template <index_t NumDimM, ...@@ -106,7 +107,7 @@ template <index_t NumDimM,
index_t NumDimK, index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
typename EDataType, typename EDataType,
...@@ -165,9 +166,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -165,9 +166,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec, static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_vec,
const std::vector<index_t>& a_ms_ks_strides_vec) const std::vector<index_t>& a_ms_ks_strides_vec)
{ {
assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
a_ms_ks_strides_vec.size() == NumDimM + NumDimK); a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
...@@ -203,100 +207,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -203,100 +207,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple(mDimIds, kDimIds), make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0); return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
} }
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_ns_ks_lengths_vec, static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_vec,
const std::vector<index_t>& b_ns_ks_strides_vec) const std::vector<index_t>& b_ns_ks_strides_vec)
{ {
assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
b_ns_ks_strides_vec.size() == NumDimN + NumDimK); b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
...@@ -332,95 +248,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -332,95 +248,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple(nDimIds, kDimIds), make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
} }
// assume E[M0, M1, M2, ..., N0, N1, N2...] // assume E[M0, M1, M2, ..., N0, N1, N2...]
...@@ -461,63 +289,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -461,63 +289,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple(mDimIds, nDimIds), make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0); return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1); }
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding || static auto MakeDsGridDescriptor_M_N(
GemmSpec == GemmSpecialization::MNKPadding) const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths_vec,
{ const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides_vec)
// pad M and N {
return transform_tensor_descriptor(e_grid_desc_mraw_nraw, return generate_tuple(
make_tuple(make_right_pad_transform(MRaw, MPad), [&](auto i) {
make_right_pad_transform(NRaw, NPad)), return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i],
make_tuple(Sequence<0>{}, Sequence<1>{}), ds_ms_ns_strides_vec[i]);
make_tuple(Sequence<0>{}, Sequence<1>{})); },
} Number<NumDTensor>{});
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return e_grid_desc_mraw_nraw;
}
} }
using AGridDesc_AK0_M_AK1 = using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{})); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using BGridDesc_BK0_N_BK1 = using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{})); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using EGridDesc_M_N =
decltype(MakeEGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
EDataType, EDataType,
...@@ -525,8 +320,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -525,8 +320,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_M_K,
BGridDesc_BK0_N_BK1, BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -561,6 +357,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -561,6 +357,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -568,27 +371,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -568,27 +371,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b_grid, const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid, void* p_e_grid,
std::vector<index_t> a_ms_ns_lengths, const std::vector<index_t>& a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths, const std::vector<index_t>& e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides, const std::vector<index_t>& e_ms_ns_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ns_lengths, a_ms_ks_strides)}, GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)}, GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -601,8 +407,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -601,8 +407,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
ds_nz_stride_{}, ds_nz_stride_{},
e_nz_stride_{} e_nz_stride_{}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, // populate pointer, batch stride, desc for Ds
b_grid_desc_bk0_n_bk1_, static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
...@@ -610,18 +430,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -610,18 +430,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) { ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
} }
// for sanity check of vector memory access // for sanity check of vector memory access
...@@ -639,6 +450,15 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -639,6 +450,15 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
} }
void Print() const
{
    // Writes each problem-definition descriptor held by this Argument to
    // std::cout, one per line, in the order A, B, every D tensor, then E.
    const auto dump_desc = [](const char* label, const auto& desc) {
        std::cout << label << desc << std::endl;
    };

    dump_desc("A[M, K]: ", a_grid_desc_m_k_);
    dump_desc("B[N, K]: ", b_grid_desc_n_k_);

    // One line per D tensor; all of them share the same "Ds[M, N]: " label.
    static_for<0, NumDTensor, 1>{}(
        [&](auto d_idx) { dump_desc("Ds[M, N]: ", ds_grid_desc_m_n_[d_idx]); });

    dump_desc("E[M, N]: ", e_grid_desc_m_n_);
}
// private: // private:
// pointers // pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -646,20 +466,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -646,20 +466,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors // tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, ds_grid_desc_mblock_mperblock_nblock_nperblock_;
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -684,29 +506,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -684,29 +506,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
{ arg.b_grid_desc_n_k_,
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" arg.ds_grid_desc_m_n_,
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
...@@ -728,9 +535,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -728,9 +535,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -754,18 +559,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -754,18 +559,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.block_2_etile_map_); arg.block_2_etile_map_);
}; };
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); return launch_kernel(integral_constant<bool, false>{});
} }
return ave_time;
} }
// polymorphic // polymorphic
...@@ -776,12 +577,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -776,12 +577,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
}; };
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
...@@ -789,8 +584,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -789,8 +584,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false; return false;
} }
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
...@@ -878,14 +674,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -878,14 +674,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
std::vector<index_t> a_ms_ns_lengths, const std::vector<index_t>& a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths, const std::vector<index_t>& e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides, const std::vector<index_t>& e_ms_ns_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
...@@ -915,14 +711,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -915,14 +711,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
std::vector<index_t> a_ms_ns_lengths, const std::vector<index_t>& a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides, const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths, const std::vector<index_t>& e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides, const std::vector<index_t>& e_ms_ns_strides,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override CDEElementwiseOperation cde_element_op) override
......
...@@ -434,9 +434,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -434,9 +434,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment