"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "69d8d78903c253b4e5ec5ad7c4f93e1b5f7affc2"
Commit 2ca29096 authored by Paul

Refactor to use correct descriptors

parent 6fda1d3e
@@ -56,18 +56,14 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
 #include <args.hpp>
 #include <migraphx/kernels/ck_gemm.hpp>
-#include <hip/hip_runtime_api.h>
 namespace migraphx {
-using gemm_t = CKDeviceGemm<${instance}>;
 extern "C" {
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
     make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
-        ck_gemm<gemm_t>(a, b, c);
+        ck_gemm<CKDeviceGemm<${instance}>>(a, b, c);
     });
 }
...
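Note: the kernel source above is a raw string that is specialized at compile-from-source time, with the ${instance} placeholder replaced by the template arguments of a concrete CK device GEMM instance. A minimal sketch of that substitution step, assuming a simple string-replace helper (the helper and the argument list below are illustrative stand-ins, not part of this commit):

#include <string>

// Hypothetical stand-in for however the compile pipeline fills in "${instance}".
std::string replace_all(std::string s, const std::string& from, const std::string& to)
{
    for(auto pos = s.find(from); pos != std::string::npos; pos = s.find(from, pos))
    {
        s.replace(pos, from.size(), to);
        pos += to.size();
    }
    return s;
}

// Usage sketch: splice a concrete instance's template arguments into the source
// before handing it to the JIT compiler. The argument list here is invented.
// std::string src = replace_all(ck_gemm_kernel, "${instance}",
//                               "Row, Col, Row, ck::half_t, float, ck::half_t, ...");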
@@ -36,16 +36,18 @@ namespace migraphx {
 template <class G, class A, class B, class C>
 __device__ void ck_gemm(const A& a, const B& b, const C& c)
 {
-    constexpr auto a_desc = to_ck_tensor<A>();
-    constexpr auto b_desc = to_ck_tensor<B>();
-    constexpr auto c_desc = to_ck_tensor<C>();
-    constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc);
-    using GridwiseGemm = typename G::template Make<decltype(a_desc), decltype(b_desc), decltype(c_desc)>;
-    // static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map));
+    constexpr G gemm{};
+    constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
+    constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
+    constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
+    constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
+    using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), decltype(b_grid_desc_bk0_n_bk1), decltype(c_grid_desc_m_n)>;
+    // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
     constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
+        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
     constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ char p_shared_block[shared_block_size];
@@ -56,11 +58,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
         b.data(),
         c.data(),
         p_shared_block,
-        G::AOp(),
-        G::BOp(),
-        G::COp(),
-        a_desc,
-        b_desc,
+        gemm.a_element_op,
+        gemm.b_element_op,
+        gemm.c_element_op,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         block_2_ctile_map);
 }
...
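Note the shape of the new code: instead of calling static members of G, the function value-initializes constexpr G gemm{} and calls constexpr member functions on it, so every descriptor is still computed at compile time. A minimal standalone analogue of this pattern (not from the commit):

// A constexpr instance whose const member functions produce compile-time
// values, mirroring how ck_gemm now uses `constexpr G gemm{}`.
struct gemm_like
{
    static constexpr int MPerBlock = 128;

    // Round a raw dimension up to a whole number of blocks, as the
    // descriptor padding below does with integer_divide_ceil.
    constexpr int padded(int raw) const
    {
        return (raw + MPerBlock - 1) / MPerBlock * MPerBlock;
    }
};

static_assert(gemm_like{}.padded(100) == 128, "partial block rounds up");
static_assert(gemm_like{}.padded(256) == 256, "exact multiple unchanged");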
@@ -149,11 +149,275 @@ template <typename ALayout,
           ck::index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
+          ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()
+          >
 struct CKDeviceGemm
 {
-    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
-    using Make = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    static constexpr auto I0 = ck::Number<0>{};
+    static constexpr auto I1 = ck::Number<1>{};
+    static constexpr auto I2 = ck::Number<2>{};
+    static constexpr auto I3 = ck::Number<3>{};
+
+    template <class Descriptor>
+    static constexpr auto
+    MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
+    {
+        const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
+        const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
+        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
+        const auto MPad = M - MRaw;
+        const auto KPad = K - KRaw;
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad both M and K
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+            const auto a_grid_desc_m_k = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(M)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // pad M, but not K
+            assert(KRaw % AK1 == 0);
+            const auto AK0 = KRaw / AK1;
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_right_pad_transform(MRaw, MPad)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
+        {
+            // pad K, but not M
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+            const auto a_grid_desc_m_k = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(ck::make_pass_through_transform(MRaw),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(MRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else
+        {
+            // not pad M or K
+            assert(KRaw % AK1 == 0);
+            const auto AK0 = KRaw / AK1;
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(MRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+    }
+
+    template <class Descriptor>
+    static constexpr auto
+    MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
+    {
+        const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
+        const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
+        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
+        const auto NPad = N - NRaw;
+        const auto KPad = K - KRaw;
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad both N and K
+            assert(K % BK1 == 0);
+            const auto BK0 = K / BK1;
+            const auto b_grid_desc_n_k = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(N)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // pad N, but not K
+            assert(KRaw % BK1 == 0);
+            const auto BK0 = KRaw / BK1;
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
+        {
+            // pad K, but not N
+            assert(K % BK1 == 0);
+            const auto BK0 = K / BK1;
+            const auto b_grid_desc_n_k = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(ck::make_pass_through_transform(NRaw),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else
+        {
+            // not pad N or K
+            assert(KRaw % BK1 == 0);
+            const auto BK0 = KRaw / BK1;
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+    }
+
+    template <class Descriptor>
+    static constexpr auto
+    MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
+    {
+        const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
+        const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
+        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto MPad = M - MRaw;
+        const auto NPad = N - NRaw;
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad M and N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
+        {
+            // pad M, but not N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
+        {
+            // pad N, but not M
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_pass_through_transform(MRaw),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else
+        {
+            // not pad M or N
+            return c_grid_desc_mraw_nraw;
+        }
+    }
+
+    // using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<8, 8, 8>());
+    // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<8, 8, 8>());
+    // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N<8, 8, 8>());
+    // using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1());
+    // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
+    // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
+
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    template <class CGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
+    }
+
+    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
+    using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
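The three Make*GridDescriptor functions added above all follow the same recipe: round each raw dimension up to a whole number of tiles, right-pad by the difference when the active GemmSpecialization asks for it, and unmerge K into (AK0, AK1) (or (BK0, BK1)) vector-sized chunks. A standalone worked example of that arithmetic, with invented sizes:

#include <cassert>

// Same rounding rule as ck::math::integer_divide_ceil.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Invented problem and tile sizes, purely for illustration.
    const int MRaw = 1000, KRaw = 60;
    const int MPerBlock = 256, KPerBlock = 32, AK1 = 8;

    const int M = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
    const int K = integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; // 64
    const int MPad = M - MRaw; // 24 rows of right padding
    const int KPad = K - KRaw; // 4 columns of right padding

    // With K padding, the unmerge splits K so that K == AK0 * AK1.
    assert(K % AK1 == 0);
    const int AK0 = K / AK1; // 8

    assert(MPad == 24 && KPad == 4 && AK0 == 8);
    return 0;
}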
@@ -197,18 +461,10 @@ struct CKDeviceGemm
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
-    static constexpr auto AOp() { return AElementwiseOperation{}; }
-    static constexpr auto BOp() { return BElementwiseOperation{}; }
-    static constexpr auto COp() { return CElementwiseOperation{}; }
+    AElementwiseOperation a_element_op{};
+    BElementwiseOperation b_element_op{};
+    CElementwiseOperation c_element_op{};
-    // return block_id to C matrix tile idx (m0, n0) mapping
-    template <class CGridDesc_M_N>
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
-            c_grid_desc_m_n);
-    }
 };
 } // namespace migraphx
...
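The last hunk replaces the static AOp()/BOp()/COp() factories with default-initialized members, which is what lets ck_gemm pass gemm.a_element_op and friends directly. For the common case these operations are stateless pass-throughs; a self-contained sketch of what such a functor looks like (CK's actual type is ck::tensor_operation::element_wise::PassThrough, shown here as a stand-in):

// Identity elementwise operation: copies the input through unchanged.
struct pass_through
{
    template <class Y, class X>
    constexpr void operator()(Y& y, const X& x) const
    {
        y = x;
    }
};

Keeping the operations as members rather than static factories also leaves room for stateful operations later (for example, a scale factor) without changing the kernel call sites.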