Commit d1a1a28b authored by Paul

Support fused version of gemm with passthrough

parent 84a7381d
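The "multiple D" device gemm this commit switches to computes E = cde_op(A*B, D0, D1, ...), where the Ds are extra elementwise inputs fused into the epilogue; with a passthrough epilogue there are no extra inputs and E is simply A*B. A minimal host-side sketch of that contract, in plain C++ with made-up names (illustration only, not the CK kernel):

#include <cstddef>
#include <vector>

// Reference semantics of the fused gemm: A is MxK, B is KxN, D and E are MxN,
// all row-major. Names and the bool switch are assumptions for this sketch.
inline void gemm_with_epilogue(const std::vector<float>& a,
                               const std::vector<float>& b,
                               const std::vector<float>& d, // optional fused input
                               std::vector<float>& e,
                               std::size_t M, std::size_t N, std::size_t K,
                               bool passthrough)
{
    for(std::size_t i = 0; i < M; ++i)
    {
        for(std::size_t j = 0; j < N; ++j)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[i * K + k] * b[k * N + j];
            // Passthrough epilogue: ignore D and write the accumulator as-is.
            e[i * N + j] = passthrough ? acc : acc + d[i * N + j];
        }
    }
}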
@@ -28,4 +28,4 @@ half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0
 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
 sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCmSoftwarePlatform/composable_kernel@639147432b6922bd8e4051ba751e4e63dd4eb196 -X header
+ROCmSoftwarePlatform/composable_kernel@1b62bfaa2a42ed83da2692f6797a5f929c39946f -X header
@@ -63,7 +63,7 @@ extern "C" {
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
     make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
-        ck_gemm<CKDeviceGemm<${instance}>>(a, b, c);
+        ck_gemm<CK_DeviceGemmMultipleD<${instance}>>(a, b, c);
     });
 }
@@ -75,7 +75,7 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
 
-static std::size_t block_size_index = 13;
+static std::size_t block_size_index = 15;
 
 static std::size_t get_block_size(const std::vector<std::string>& s)
 {
@@ -154,8 +154,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto i = v.get("tuning_val", get_tuning_for(inputs));
         const auto& instance = get_instance(i, [&](const auto& x) -> bool {
             return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
-                   get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
-                   get_type(b_shape) == x[4] and get_type(c_shape) == x[5];
+                   get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
         });
         hip_compile_options options;
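For reference, ${instance} is pasted directly into CK_DeviceGemmMultipleD<...>, so the tokenized instance fields follow the template parameter order introduced later in this commit; that is why the checks above moved to x[3], x[4], x[5], x[9] and block_size_index moved from 13 to 15. A sketch of the inferred positions (not code from this commit):

#include <cstddef>

// Inferred field positions in a tokenized CK_DeviceGemmMultipleD instance string.
enum ck_instance_field : std::size_t
{
    a_layout        = 0,
    b_layout        = 1,
    ds_layout       = 2,
    e_layout        = 3,  // compared against get_layout(c_shape)
    a_type          = 4,  // compared against get_type(a_shape)
    b_type          = 5,  // compared against get_type(b_shape)
    acc_type        = 6,
    cshuffle_type   = 7,
    ds_type         = 8,
    e_type          = 9,  // compared against get_type(c_shape)
    a_elementwise   = 10,
    b_elementwise   = 11,
    cde_elementwise = 12,
    gemm_spec       = 13,
    num_prefetch    = 14,
    block_size      = 15  // the new block_size_index
};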
......
@@ -31,7 +31,7 @@ constexpr bool is_row_major()
     MIGRAPHX_ASSERT(strides.size() >= 2);
     if(strides.back() == 1)
     {
-        MIGRAPHX_ASSERT(not Shape{}.is_trasnposed());
+        MIGRAPHX_ASSERT(not Shape{}.is_transposed());
         return true;
     }
     MIGRAPHX_ASSERT(strides[strides.size() - 2] == 1);
@@ -59,5 +59,26 @@ constexpr auto to_ck_tensor()
     });
 }
 
+template <class F>
+struct ck_function_adaptor : F
+{
+    template <class... Ts>
+    constexpr ck_function_adaptor(Ts&&... xs) : F(static_cast<Ts&&>(xs)...)
+    {}
+
+    template <class T, class... Ts>
+    constexpr void operator()(T& out, Ts&&... xs) const
+    {
+        out = static_cast<const F&>(*this)(static_cast<Ts&&>(xs)...);
+    }
+};
+
+struct ck_nop
+{
+    template <class T>
+    constexpr void operator()(T&) const
+    {}
+};
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
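A hypothetical usage sketch for the two helpers added above (the functor, names, and values are illustrative, and it assumes the kernels/ck.hpp header from this hunk is included): ck_function_adaptor exposes a value-returning functor through CK's out-parameter elementwise convention op(out, ins...), while ck_nop leaves its output untouched and serves as the passthrough when nothing is fused.

// Illustrative only: a value-returning functor to be adapted.
struct scale_by_two
{
    template <class T>
    constexpr T operator()(T x) const
    {
        return x + x;
    }
};

// The adaptor writes the functor's result through the output reference.
constexpr migraphx::ck_function_adaptor<scale_by_two> scale_op{};

constexpr float apply_scale(float x)
{
    float out = 0.0f;
    scale_op(out, x); // out = scale_by_two{}(x)
    return out;
}
static_assert(apply_scale(3.0f) == 6.0f, "adaptor forwards arguments and assigns the result");

// migraphx::ck_nop{}(out) would leave out unchanged; it is the passthrough hook.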
@@ -33,38 +33,48 @@
 namespace migraphx {
 
-template <class G, class A, class B, class C>
-__device__ void ck_gemm(const A& a, const B& b, const C& c)
+template <class G, class A, class B, class E, class... Ds>
+__device__ void ck_gemm(A a, B b, E e, Ds... ds)
 {
-    constexpr const auto a_grid_desc_ak0_m_ak1 = G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
-    constexpr const auto b_grid_desc_bk0_n_bk1 = G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
-    constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
-    constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
-    using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
-                                                           decltype(b_grid_desc_bk0_n_bk1),
-                                                           decltype(c_grid_desc_m_n)>;
-    // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
-    //                                           c_grid_desc_m_n, block_2_ctile_map));
-    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
+    constexpr const G gemm{};
+
+    constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
+    constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>());
+    constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
+    constexpr const auto ds_grid_desc_m_n =
+        ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
+    constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
+
+    using GridwiseGemm = typename G::GridwiseGemm;
+
+    // tensor descriptors for block/thread-wise copy
+    constexpr auto a_grid_desc_ak0_m_ak1 =
+        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+    constexpr auto b_grid_desc_bk0_n_bk1 =
+        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+    constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
+        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            ds_grid_desc_m_n);
+    constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            e_grid_desc_m_n);
 
     __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     constexpr const bool HasMainKBlockLoop =
-        GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
+        GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
+                                                 a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
 
     GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
                                                   b.data(),
-                                                  c.data(),
+                                                  ck::make_tuple(ds.data()...),
+                                                  e.data(),
                                                   p_shared_block,
-                                                  G{}.a_element_op,
-                                                  G{}.b_element_op,
-                                                  G{}.c_element_op,
+                                                  gemm.a_element_op,
+                                                  gemm.b_element_op,
+                                                  gemm.cde_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  block_2_ctile_map);
+                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_etile_map);
 }
 
 } // namespace migraphx
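The HasMainKBlockLoop predicate now recovers K from the blocked (AK0, M, AK1) descriptor rather than from A's element count, since K factors as AK0 * AK1. A minimal sketch of that arithmetic with made-up lengths:

#include <cstddef>

constexpr std::size_t ak0 = 32; // GetLength(ck::Number<0>{}) of the A descriptor
constexpr std::size_t ak1 = 8;  // GetLength(ck::Number<2>{})
constexpr std::size_t k   = ak0 * ak1;
static_assert(k == 256, "K depends only on the AK0 and AK1 lengths, not on M");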
......
@@ -36,6 +36,8 @@
 #include <ck/tensor_operation/gpu/device/device_gemm.hpp>
 #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
 #include <ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp>
+#include <ck/tensor_operation/gpu/device/matrix_padder.hpp>
+#include <ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp>
 
 namespace migraphx {
@@ -110,15 +112,17 @@ struct BlockToCTileMap_M00_N0_M01Adapt
 template <typename ALayout,
           typename BLayout,
-          typename CLayout,
+          typename DsLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
-          typename CDataType,
-          typename GemmAccDataType,
+          typename AccDataType,
           typename CShuffleDataType,
+          typename DsDataType,
+          typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
           ck::tensor_operation::device::GemmSpecialization GemmSpec,
           ck::index_t NumGemmKPrefetchStage,
           ck::index_t BlockSize,
@@ -137,294 +141,34 @@ template <typename ALayout,
           ck::index_t ABlockTransferSrcVectorDim,
           ck::index_t ABlockTransferSrcScalarPerVector,
           ck::index_t ABlockTransferDstScalarPerVector_AK1,
-          bool ABlockLdsExtraM,
+          ck::index_t ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           ck::index_t BBlockTransferSrcVectorDim,
           ck::index_t BBlockTransferSrcScalarPerVector,
           ck::index_t BBlockTransferDstScalarPerVector_BK1,
-          bool BBlockLdsExtraN,
+          ck::index_t BBlockLdsExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
-          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
           ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
-struct CKDeviceGemm
+struct CK_DeviceGemmMultipleD
 {
-    static constexpr auto I0 = ck::Number<0>{};
-    static constexpr auto I1 = ck::Number<1>{};
-    static constexpr auto I2 = ck::Number<2>{};
-    static constexpr auto I3 = ck::Number<3>{};
template <class Descriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
{
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_right_pad_transform(MRaw, MPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
template <class Descriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
{
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
            const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(NRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
template <class Descriptor>
static constexpr auto MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
{
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
// using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<8, 8, 8>());
// using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<8, 8, 8>());
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N<8, 8, 8>());
// using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1());
// using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping
template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
-    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
-    using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
+        matrix_padder{MPerBlock, NPerBlock, KPerBlock};
+
+    // GridwiseGemm
+    using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
-        GemmAccDataType,
+        AccDataType,
         CShuffleDataType,
-        CDataType,
+        DsDataType,
+        EDataType,
         AElementwiseOperation,
         BElementwiseOperation,
-        CElementwiseOperation,
+        CDEElementwiseOperation,
         ck::InMemoryDataOperationEnum::Set,
-        AGridDesc_AK0_M_AK1,
-        BGridDesc_BK0_N_BK1,
-        CGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -454,13 +198,22 @@ struct CKDeviceGemm
         BBlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
-        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CShuffleBlockTransferScalarPerVector_NPerBlock,
+        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
+    // return block_id to E matrix tile idx (m0, n0) mapping
+    template <class EGridDesc_M_N>
+    __device__ static constexpr auto
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
+            e_grid_desc_m_n_);
+    }
+
     AElementwiseOperation a_element_op{};
     BElementwiseOperation b_element_op{};
-    CElementwiseOperation c_element_op{};
+    CDEElementwiseOperation cde_element_op{};
 };
 } // namespace migraphx
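For context, MakeDefaultBlock2ETileMap (via BlockToCTileMap_M00_N0_M01Adapt) turns a flat workgroup id into an (m0, n0) tile of the E matrix. A self-contained sketch of the idea under a plain row-major tile ordering; the real adaptor applies its own swizzling, so this is illustrative only:

#include <cstddef>
#include <utility>

static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

// Illustrative mapping from a flat block id to an (m0, n0) output tile.
std::pair<std::size_t, std::size_t>
block_to_e_tile(std::size_t block_id, std::size_t N, std::size_t n_per_block)
{
    const std::size_t n_tiles = int_div_ceil(N, n_per_block); // tiles along N
    return {block_id / n_tiles, block_id % n_tiles};          // (m0, n0)
}

// One block per tile: grid = int_div_ceil(M, MPerBlock) * int_div_ceil(N, NPerBlock).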
......