Commit c51b3d29 authored by Paul

Some more simplifications

parent f8e5a547
@@ -60,22 +60,14 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
 namespace migraphx {
-using gemm_t = CKDeviceGemm<${instance}, ${m}, ${k}, ${n}, ${sa}, ${sb}, ${sc}>;
-constexpr __device__ gemm_t ckdg{};
-using GridwiseGemm = decltype(ckdg.gridwisegemm);
+using gemm_t = CKDeviceGemm<${instance}>;
 extern "C" {
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
-    make_tensors()(a_p, b_p, c_p)([&](auto a_t, auto b_t, auto c_t) {
-        constexpr ck::index_t shared_block_size =
-            GridwiseGemm::GetSharedMemoryNumberOfByte();
-        __shared__ char p_shared_block[shared_block_size];
-        make_tensors()(p_shared_block)([&](auto p_t) {
-            ck_gemm<gemm_t>(a_t, b_t, c_t, p_t);
-        });
+    make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
+        ck_gemm<gemm_t>(a, b, c);
     });
 }
...
#ifndef MIGRAPHX_GUARD_KERNELS_CK_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_HPP
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
namespace migraphx {
namespace detail {
template<class T>
struct to_ck_type_impl
{
using type = T;
};
template<>
struct to_ck_type_impl<migraphx::half>
{
using type = ck::half_t;
};
template<class Shape>
constexpr bool is_row_major()
{
constexpr auto strides = Shape{}.strides;
MIGRAPHX_ASSERT(strides.size() >= 2);
if (strides.back() == 1)
{
MIGRAPHX_ASSERT(not Shape{}.is_transposed());
return true;
}
MIGRAPHX_ASSERT(strides[strides.size() - 2] == 1);
return false;
}
} // namespace detail
template<class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(), ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor>;
template<class Tensor>
constexpr auto to_ck_tensor()
{
constexpr auto s = get_shape_c<Tensor>{};
return sequence(s.lens.size(), [](auto... is) {
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...), ck::make_tuple(s.strides[is]...));
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
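
As a side note, here is a minimal standalone sketch of the stride-based layout check that is_row_major performs above, written against a plain struct rather than MIGraphX's compile-time shape type (Shape2D and the sizes below are illustrative only, not part of this commit):

#include <array>
#include <cstdio>

// Illustrative stand-in for a rank-2 shape: lengths plus strides.
struct Shape2D
{
    std::array<int, 2> lens;
    std::array<int, 2> strides;
};

// Row-major iff the innermost (last) dimension has unit stride; a
// column-major matrix instead has unit stride on the leading dimension.
constexpr bool is_row_major(const Shape2D& s) { return s.strides.back() == 1; }

int main()
{
    constexpr Shape2D row{{4, 8}, {8, 1}}; // 4x8 matrix laid out row by row
    constexpr Shape2D col{{4, 8}, {1, 4}}; // same matrix laid out column by column
    static_assert(is_row_major(row), "unit stride on the last dimension");
    static_assert(not is_row_major(col), "unit stride on the first dimension");
    std::printf("row-major: %d, column-major: %d\n", is_row_major(row), is_row_major(col));
    return 0;
}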
@@ -28,62 +28,43 @@
 #include <migraphx/kernels/algorithm.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/tensor_view.hpp>
+#include <migraphx/kernels/ck.hpp>
 #include <migraphx/kernels/ck_gemm_includes.hpp>
 namespace migraphx {
-template <class G, class T, class U, class V, class W>
-__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, W& p_t)
-{
-    constexpr G ckdg{};
-    using GridwiseGemm = decltype(ckdg.gridwisegemm);
-    constexpr auto a_grid_desc_ak0_m_ak1 = ckdg.MakeAGridDescriptor_AK0_M_AK1();
-    constexpr auto b_grid_desc_bk0_n_bk1 = ckdg.MakeBGridDescriptor_BK0_N_BK1();
-    constexpr auto c_grid_desc_m_n = ckdg.MakeCGridDescriptor_M_N();
-    constexpr auto block_2_ctile_map = ckdg.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
-    // static_assert(GridwiseGemm::CheckValidity(
-    //     a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
+template <class G, class A, class B, class C>
+__device__ void ck_gemm(const A& a, const B& b, const C& c)
+{
+    constexpr auto a_desc = to_ck_tensor<A>();
+    constexpr auto b_desc = to_ck_tensor<B>();
+    constexpr auto c_desc = to_ck_tensor<C>();
+    constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc);
+    using GridwiseGemm = typename G::template Make<a_desc, b_desc, c_desc>;
+    // static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map));
     constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
-    constexpr auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-    constexpr auto a_element_op = ckdg.a_element_op;
-    constexpr auto b_element_op = ckdg.b_element_op;
-    constexpr auto c_element_op = ckdg.c_element_op;
-    if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-    {
-        constexpr bool HasMainKBlockLoop = true;
-        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
-                                                      b_t.data(),
-                                                      c_t.data(),
-                                                      p_t.data(),
-                                                      a_element_op,
-                                                      b_element_op,
-                                                      c_element_op,
-                                                      a_grid_desc_ak0_m_ak1,
-                                                      b_grid_desc_bk0_n_bk1,
-                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                      block_2_ctile_map);
-    }
-    else
-    {
-        constexpr bool HasMainKBlockLoop = false;
-        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
-                                                      b_t.data(),
-                                                      c_t.data(),
-                                                      p_t.data(),
-                                                      a_element_op,
-                                                      b_element_op,
-                                                      c_element_op,
-                                                      a_grid_desc_ak0_m_ak1,
-                                                      b_grid_desc_bk0_n_bk1,
-                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                      block_2_ctile_map);
-    }
+        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
+    constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
+    __shared__ char p_shared_block[shared_block_size];
+    constexpr bool HasMainKBlockLoop =
+        GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
+    GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
+                                                  b.data(),
+                                                  c.data(),
+                                                  p_shared_block,
+                                                  G::AOp(),
+                                                  G::BOp(),
+                                                  G::COp(),
+                                                  a_desc,
+                                                  b_desc,
+                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_ctile_map);
 }
 } // namespace migraphx
...
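
The change above also collapses the old runtime if/else over HasMainKBlockLoop into a single constexpr value forwarded as a template argument to Run. A small standalone sketch of that dispatch pattern, using illustrative names and sizes that are not CK or MIGraphX code:

#include <cstdio>

// Stand-in for GridwiseGemm::CalculateHasMainKBlockLoop: true when the
// problem needs more than one K block.
constexpr bool has_main_k_block_loop(int k, int k_per_block) { return k > k_per_block; }

// Stand-in for GridwiseGemm::Run: the flag is a template parameter, so the
// branch is resolved at compile time instead of per kernel launch.
template <bool HasMainKBlockLoop>
void run()
{
    if constexpr(HasMainKBlockLoop)
        std::printf("main K-block loop enabled\n");
    else
        std::printf("single K block, loop elided\n");
}

int main()
{
    constexpr int k = 256, k_per_block = 32;      // illustrative sizes
    run<has_main_k_block_loop(k, k_per_block)>(); // one instantiation, no runtime if/else
    return 0;
}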
@@ -39,29 +39,6 @@
 namespace migraphx {
-static constexpr auto I0 = ck::Number<0>{};
-static constexpr auto I1 = ck::Number<1>{};
-static constexpr auto I2 = ck::Number<2>{};
-static constexpr auto I3 = ck::Number<3>{};
-static constexpr auto I4 = ck::Number<4>{};
-static constexpr auto I5 = ck::Number<5>{};
-static constexpr ck::index_t K1 = 1;
-static constexpr auto K1Number = ck::Number<K1>{};
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-using F16 = ck::half_t;
-using F32 = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
 template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
 struct BlockToCTileMap_M00_N0_M01Adapt
 {
@@ -172,303 +149,12 @@ template <typename ALayout,
           ck::index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          ck::index_t MRaw,
-          ck::index_t KRaw,
-          ck::index_t NRaw,
-          ck::index_t StrideA,
-          ck::index_t StrideB,
-          ck::index_t StrideC,
           ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
 struct CKDeviceGemm
 {
-    // template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA>
-    static constexpr auto MakeAGridDescriptor_AK0_M_AK1()
-    {
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
ck::make_tuple(StrideA, I1));
}
else if constexpr(ck::is_same_v<ck::tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
ck::make_tuple(I1, StrideA));
}
}();
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both M and K
static_assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad M, but not K
static_assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_right_pad_transform(MRaw, MPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad K, but not M
static_assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
static_assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
// template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1()
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
ck::make_tuple(I1, StrideB));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
ck::make_tuple(StrideB, I1));
}
}();
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both N and K
static_assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad N, but not K
static_assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad K, but not N
static_assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(NRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
static_assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
// template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC>
static constexpr auto MakeCGridDescriptor_M_N()
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
ck::make_tuple(StrideC, I1));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
ck::make_tuple(I1, StrideC));
}
}();
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
// using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<8, 8, 8>());
// using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<8, 8, 8>());
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N<8, 8, 8>());
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1());
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
-    using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
+    using Make =
+        ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -513,10 +199,17 @@ struct CKDeviceGemm
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
-    GridwiseGemm gridwisegemm{};
-    AElementwiseOperation a_element_op{};
-    BElementwiseOperation b_element_op{};
-    CElementwiseOperation c_element_op{};
+    static constexpr auto AOp() { return AElementwiseOperation{}; }
+    static constexpr auto BOp() { return BElementwiseOperation{}; }
+    static constexpr auto COp() { return CElementwiseOperation{}; }
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    template <class CGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
+    }
 };
 } // namespace migraphx
...
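
After this commit CKDeviceGemm acts as a compile-time policy: the grid descriptors are bound through the nested Make alias, and the elementwise operations are exposed as static functions rather than stored members, so ck_gemm can pull everything it needs from the type G alone. A rough standalone sketch of that pattern, assuming illustrative stand-ins (DeviceGemmPolicy, GridwiseKernel, call_gemm) rather than the real CK types:

#include <cstdio>

// Illustrative elementwise op, standing in for CK's PassThrough.
struct PassThrough
{
    template <class T>
    constexpr T operator()(T x) const { return x; }
};

// Illustrative "gridwise" kernel parameterized by the descriptor types it is
// bound to, standing in for ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1.
template <class ADesc, class BDesc, class CDesc>
struct GridwiseKernel
{
    static void run() { std::printf("running bound kernel\n"); }
};

// Policy struct shaped like the simplified CKDeviceGemm: descriptors are
// bound through the nested Make alias, and the elementwise operations come
// from static accessors instead of stored members.
struct DeviceGemmPolicy
{
    template <class ADesc, class BDesc, class CDesc>
    using Make = GridwiseKernel<ADesc, BDesc, CDesc>;

    static constexpr auto AOp() { return PassThrough{}; }
    static constexpr auto BOp() { return PassThrough{}; }
    static constexpr auto COp() { return PassThrough{}; }
};

// Generic caller mirroring ck_gemm: everything comes from the policy type G.
template <class G, class ADesc, class BDesc, class CDesc>
void call_gemm()
{
    using Kernel = typename G::template Make<ADesc, BDesc, CDesc>;
    Kernel::run();
    std::printf("a_op(2) = %d\n", G::AOp()(2));
}

struct DescA {};
struct DescB {};
struct DescC {};

int main()
{
    call_gemm<DeviceGemmPolicy, DescA, DescB, DescC>();
    return 0;
}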