Commit 127393f4 authored by turneram

Call gemm from kernel

parent 9d12476e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct ck_elementwise
{
    std::string name() const { return "ck_elementwise"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_dims();
        auto s0 = inputs.at(0);
        auto s1 = inputs.at(1);
        if(s0 == s1 and s0.packed())
        {
            return s0;
        }
        else if(s0.packed() != s1.packed())
        {
            return s0.packed() ? s0 : s1;
        }
        else if(s0.broadcasted() != s1.broadcasted())
        {
            return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
        }
        else
        {
            return {s0.type(), s0.lens()};
        }
    }

    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
            par_for(output_shape.elements(),
                    [&](const auto i) { output[i] = input1[i] + input2[i]; });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
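For context, a quick host-side sketch of the shape-selection rules above. This is hypothetical usage, not part of the commit: the include path and the padded-stride shape are illustrative.

#include <cassert>
#include <migraphx/op/ck_elementwise.hpp> // assumed install path for the header above

int main()
{
    migraphx::op::ck_elementwise op;
    migraphx::shape packed{migraphx::shape::float_type, {4, 8}};          // contiguous, packed
    migraphx::shape padded{migraphx::shape::float_type, {4, 8}, {16, 1}}; // padded rows, not packed
    // Exactly one input is packed, so the packed shape is chosen:
    assert(op.compute_shape({packed, padded}) == packed);
    // Identical packed shapes pass through unchanged:
    assert(op.compute_shape({packed, packed}) == packed);
    return 0;
}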
@@ -47,14 +47,24 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
 #include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
+#include <hip/hip_runtime_api.h>
 
 namespace migraphx {
 
 extern "C" {
 
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
-    make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
-        ck_gemm(xs...);
+    hipDeviceProp_t hdp{};
+    printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
+    // make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
+    //     ck_gemm(xs...);
+    // });
+    make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
+        __shared__ void* p_shared_block[(a_t.get_shape().elements()/* + b_t.get_shape().elements() */) * 2];
+        make_tensors()(p_shared_block)([&](auto p_t) {
+            ck_gemm(a_t, b_t, c_t, p_t);
+        });
     });
 }
...
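One caveat in the hunk above: hdp is default-constructed inside the kernel, so the "Shared mem:" print always reports zero. The LDS limit has to be queried host-side with the HIP runtime; a minimal sketch (standard HIP API, device 0 assumed):

#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    hipDeviceProp_t props{};
    // Fill the struct from the runtime instead of default-constructing it.
    if(hipGetDeviceProperties(&props, 0) == hipSuccess)
        std::printf("Shared mem per block: %zu bytes\n", props.sharedMemPerBlock);
    return 0;
}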
@@ -26,6 +26,8 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/algorithm.hpp>
+#include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/tensor_view.hpp>
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -39,143 +41,249 @@
 namespace migraphx {
 
-// static constexpr auto I0 = Number<0>{};
-// static constexpr auto I1 = Number<1>{};
-// static constexpr auto I2 = Number<2>{};
-// static constexpr auto I3 = Number<3>{};
-// static constexpr auto I4 = Number<4>{};
-// static constexpr auto I5 = Number<5>{};
-// static constexpr auto K1Number = Number<1>{};
-
-// static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-// {
-//     assert(K % K1 == 0);
-
-//     const index_t K0 = K / K1;
-
-//     const auto a_grid_desc_m_k = [&]() {
-//         if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-//         }
-//         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-//         }
-//     }();
-
-//     if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-//     {
-//         const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-//         return transform_tensor_descriptor(
-//             a_grid_desc_m_k,
-//             make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-//                        make_right_pad_transform(M, PadM)),
-//             make_tuple(Sequence<1>{}, Sequence<0>{}),
-//             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-//     }
-//     else
-//     {
-//         return transform_tensor_descriptor(
-//             a_grid_desc_m_k,
-//             make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-//                        make_pass_through_transform(M)),
-//             make_tuple(Sequence<1>{}, Sequence<0>{}),
-//             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-//     }
-// }
-
-// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-// {
-//     assert(K % K1 == 0);
-
-//     const index_t K0 = K / K1;
-
-//     const auto b_grid_desc_k_n = [&]() {
-//         if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-//         }
-//         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-//         }
-//     }();
-
-//     if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-//     {
-//         const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-//         return transform_tensor_descriptor(
-//             b_grid_desc_k_n,
-//             make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-//                        make_right_pad_transform(N, PadN)),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}),
-//             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-//     }
-//     else
-//     {
-//         return transform_tensor_descriptor(
-//             b_grid_desc_k_n,
-//             make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-//                        make_pass_through_transform(N)),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}),
-//             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-//     }
-// }
-
-// static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
-// {
-//     const auto c_grid_desc_m_n = [&]() {
-//         if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
-//         }
-//         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
-//         {
-//             return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
-//         }
-//     }();
-
-//     if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-//     {
-//         const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-//         const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-//         return transform_tensor_descriptor(
-//             c_grid_desc_m_n,
-//             make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}));
-//     }
-//     else
-//     {
-//         return transform_tensor_descriptor(
-//             c_grid_desc_m_n,
-//             make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}),
-//             make_tuple(Sequence<0>{}, Sequence<1>{}));
-//     }
-// }
+static constexpr auto I0 = ck::Number<0>{};
+static constexpr auto I1 = ck::Number<1>{};
+static constexpr auto I2 = ck::Number<2>{};
+static constexpr auto I3 = ck::Number<3>{};
+static constexpr auto I4 = ck::Number<4>{};
+static constexpr auto I5 = ck::Number<5>{};
+
+static constexpr ck::index_t K1 = 1;
+static constexpr auto K1Number  = ck::Number<K1>{};
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using ALayout = Col;
+using BLayout = Row;
+using CLayout = Row;
+
+using ADataType   = float;
+using BDataType   = float;
+using CDataType   = float;
+using AccDataType = float;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+// Values hard-coded by CK
+static constexpr ck::index_t MPerBlock   = 128;
+static constexpr ck::index_t NPerBlock   = 128;
+static constexpr ck::index_t BlockSize   = 256;
+static constexpr ck::index_t K0PerBlock  = 16;
+static constexpr ck::index_t M1PerThread = 4;
+static constexpr ck::index_t N1PerThread = 4;
+static constexpr ck::index_t KPerThread  = 1;
+using M1N1ThreadClusterM1Xs = S<8, 2>;
+using M1N1ThreadClusterN1Xs = S<8, 2>;
+using ABlockTransferThreadSliceLengths_K0_M0_M1_K1     = S<2, 1, 4, 1>;
+using ABlockTransferThreadClusterLengths_K0_M0_M1_K1   = S<8, 1, 32, 1>;
+using ABlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
+using ABlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
+using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
+using ABlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
+using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
+using BBlockTransferThreadSliceLengths_K0_N0_N1_K1     = S<2, 1, 4, 1>;
+using BBlockTransferThreadClusterLengths_K0_N0_N1_K1   = S<8, 1, 32, 1>;
+using BBlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
+using BBlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
+using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
+using BBlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
+using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
+using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
+static constexpr ck::index_t CThreadTransferSrcDstVectorDim    = 5;
+static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
+
+static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
+{
+    assert(K % K1 == 0);
+
+    const ck::index_t K0 = K / K1;
+
+    const auto a_grid_desc_m_k = [&]() {
+        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(StrideA, I1));
+        }
+        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, ALayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(I1, StrideA));
+        }
+    }();
+
+    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+    {
+        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
+                           ck::make_right_pad_transform(M, PadM)),
+            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    }
+    else
+    {
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
+                           ck::make_pass_through_transform(M)),
+            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    }
+}
+
+static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ck::index_t StrideB)
+{
+    assert(K % K1 == 0);
+
+    const ck::index_t K0 = K / K1;
+
+    const auto b_grid_desc_k_n = [&]() {
+        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(StrideB, I1));
+        }
+        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(I1, StrideB));
+        }
+    }();
+
+    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+    {
+        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
+                           ck::make_right_pad_transform(N, PadN)),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    }
+    else
+    {
+        return transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
+                           ck::make_pass_through_transform(N)),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    }
+}
+
+static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::index_t StrideC)
+{
+    const auto c_grid_desc_m_n = [&]() {
+        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(StrideC, I1));
+        }
+        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
+        {
+            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(I1, StrideC));
+        }
+    }();
+
+    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+    {
+        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+
+        return transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            ck::make_tuple(ck::make_right_pad_transform(M, PadM),
+                           ck::make_right_pad_transform(N, PadN)),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+    }
+    else
+    {
+        return transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            ck::make_tuple(ck::make_pass_through_transform(M),
+                           ck::make_pass_through_transform(N)),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+    }
+}
+
+using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
 
-template <class T, class U, class V>
-__device__ void ck_gemm(const T& /* a_t */, const U& /* b_t */, const V& /* c_t */)
+template <class T, class U, class V, class W>
+__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
 {
     constexpr auto alens = get_shape_c<T>{}.lens;
     constexpr auto m = alens[0];
     constexpr auto k = alens[1];
-    constexpr auto alens = get_shape_c<U>{}.lens;
-    constexpr auto n = alens[1];
+    constexpr auto blens = get_shape_c<U>{}.lens;
+    constexpr auto n = blens[1];
     constexpr auto astrides = get_shape_c<T>{}.strides;
-    constexpr auto as = astrides[1];
+    constexpr auto as = astrides[0];
     constexpr auto bstrides = get_shape_c<U>{}.strides;
-    constexpr auto bs = bstrides[1];
+    constexpr auto bs = bstrides[0];
    constexpr auto cstrides = get_shape_c<V>{}.strides;
-    constexpr auto cs = cstrides[1];
-    printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
+    constexpr auto cs = cstrides[0];
+    auto idx = make_index();
+    if(idx.global == 0)
+        printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
+
+    auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
+        static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
+    auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
+        static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
+    auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+        static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
+
+    using GridwiseGemm =
+        ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
+                                         ADataType,
+                                         AccDataType,
+                                         CDataType,
+                                         ck::InMemoryDataOperationEnum::Set,
+                                         AGridDesc_K0_M_K1,
+                                         BGridDesc_K0_N_K1,
+                                         CGridDesc_M_N,
+                                         MPerBlock,
+                                         NPerBlock,
+                                         K0PerBlock,
+                                         M1PerThread,
+                                         N1PerThread,
+                                         KPerThread,
+                                         M1N1ThreadClusterM1Xs,
+                                         M1N1ThreadClusterN1Xs,
+                                         ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
+                                         ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
+                                         ABlockTransferThreadClusterArrangeOrder,
+                                         ABlockTransferSrcAccessOrder,
+                                         ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
+                                         ABlockTransferSrcVectorTensorContiguousDimOrder,
+                                         ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
+                                         BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
+                                         BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
+                                         BBlockTransferThreadClusterArrangeOrder,
+                                         BBlockTransferSrcAccessOrder,
+                                         BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
+                                         BBlockTransferSrcVectorTensorContiguousDimOrder,
+                                         BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
+                                         CThreadTransferSrcDstAccessOrder,
+                                         CThreadTransferSrcDstVectorDim,
+                                         CThreadTransferDstScalarPerVector>;
+
+    auto a_grid_desc_k0_m0_m1_k1 =
+        GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
+    auto b_grid_desc_k0_n0_n1_k1 =
+        GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
+    auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
+        GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
+    auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
+
+    constexpr bool HasMainKBlockLoop       = true;
+    constexpr bool HasDoubleTailKBlockLoop = true;
+    GridwiseGemm::Run(a_t.data(),
+                      b_t.data(),
+                      c_t.data(),
+                      p_t.data(),
+                      a_grid_desc_k0_m0_m1_k1,
+                      b_grid_desc_k0_n0_n1_k1,
+                      c_grid_desc_m0_m10_m11_n0_n10_n11,
+                      block_2_ctile_map,
+                      ck::integral_constant<bool, HasMainKBlockLoop>{},
+                      ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 
 } // namespace migraphx
...
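To make the tiling concrete, here is a sketch of the tile arithmetic for the shapes used by the updated verify test below (A 128x256, B 256x256). It assumes it sits in the same translation unit, so K1, MPerBlock, etc. refer to the constants defined above:

// Worked numbers for M = 128, N = 256, K = 256 (sketch, not CK's authoritative mapping):
constexpr ck::index_t M = 128, N = 256, K = 256;
constexpr ck::index_t K0      = K / K1;        // 256 / 1 = 256 slices along K
constexpr ck::index_t tiles_m = M / MPerBlock; // 128 / 128 = 1
constexpr ck::index_t tiles_n = N / NPerBlock; // 256 / 128 = 2
// 1 x 2 = 2 C-tiles, each handled by one workgroup of BlockSize (256) threads;
// each tile consumes K0 / K0PerBlock = 256 / 16 = 16 K-blocks in the main loop.
static_assert(M % MPerBlock == 0 and N % NPerBlock == 0, "Default spec takes the no-padding path");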
@@ -33,8 +33,8 @@ struct ck_gemm : verify_program<ck_gemm>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
-    migraphx::shape m2_shape{migraphx::shape::float_type, {20, 20}};
+    migraphx::shape m1_shape{migraphx::shape::float_type, {128, 256}};
+    migraphx::shape m2_shape{migraphx::shape::float_type, {256, 256}};
     auto l1 = mm->add_parameter("1", m1_shape);
     auto l2 = mm->add_parameter("2", m2_shape);
...
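A likely reading of this test change (not stated in the commit itself): with GemmSpec = Default the kernel takes the no-padding descriptor path, so M and N must be multiples of the 128x128 block tile. The old 10x20 problem would have needed the MNPadding path:

// Right-pad amounts per the PadM/PadN formula in the kernel header above:
constexpr int pad_m_old = (128 - 10 % 128) % 128;  // 118 padded rows for M = 10
constexpr int pad_m_new = (128 - 128 % 128) % 128; // 0 for M = 128 - no padding needed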