Commit 07167910 authored by turneram

Add xdl fp16 gemm

parent d1f5753a
set(CGET_PREFIX "/code/AMDMIGraphX/AMDMIGraphX/cget")
set(CMAKE_PREFIX_PATH "/code/AMDMIGraphX/AMDMIGraphX/cget")
if (${CMAKE_VERSION} VERSION_LESS "3.6.0")
include_directories(SYSTEM ${CGET_PREFIX}/include)
else ()
set(CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES "${CGET_PREFIX}/include")
set(CMAKE_C_STANDARD_INCLUDE_DIRECTORIES "${CGET_PREFIX}/include")
endif()
if (CMAKE_CROSSCOMPILING)
list(APPEND CMAKE_FIND_ROOT_PATH "/code/AMDMIGraphX/AMDMIGraphX/cget")
endif()
if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "/code/AMDMIGraphX/AMDMIGraphX/cget")
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
set(CMAKE_CXX_ENABLE_PARALLEL_BUILD_FLAG "/MP")
endif()
if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS "ON" CACHE BOOL "")
endif()
set(CMAKE_FIND_FRAMEWORK "LAST" CACHE STRING "")
set(CMAKE_INSTALL_RPATH "${CGET_PREFIX}/lib" CACHE STRING "")
...@@ -28,4 +28,4 @@ half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@639147432b6922bd8e4051ba751e4e63dd4eb196 -X header
ROCmSoftwarePlatform/composable_kernel -X header
...@@ -362,6 +362,8 @@ foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Add ck includes
find_path(CK_INCLUDE_PATH ck/ck.hpp)
set (CK_INCLUDE_PATH "/code/AMDMIGraphX/AMDMIGraphX/depend/cget/include/")
message(STATUS "CK path: ${CK_INCLUDE_PATH}")
string(APPEND HIP_COMPILER_FLAGS " -isystem ${CK_INCLUDE_PATH}")
endforeach()
......
...@@ -45,7 +45,7 @@ using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const ck_elementwise_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ck_elementwise2.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
......
...@@ -40,43 +40,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
// static const char* const ck_gemm_kernel = R"__migraphx__(
// #include <migraphx/kernels/ck_gemm.hpp>
// #include <migraphx/kernels/ops.hpp>
// #include <migraphx/kernels/integral_constant.hpp>
// #include <migraphx/kernels/generic_constant.hpp>
// #include <args.hpp>
// #include <hip/hip_runtime_api.h>
// namespace migraphx {
// extern "C" {
// __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
// {
// // hipDeviceProp_t hdp{};
// // printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
// // make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
// // ck_gemm(xs...);
// // });
// make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
// __shared__ float p_shared_block[512]; //[(a_t.get_shape().elements() +
// b_t.get_shape().elements()) * 2]; ck_gemm(a_t, b_t, c_t, p_shared_block);
// // make_tensors()(p_shared_block)([&](auto p_t) {
// // ck_gemm(a_t, b_t, c_t, p_t);
// // });
// });
// }
// }
// } // namespace migraphx
// )__migraphx__";
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_includes.hpp>
#include <migraphx/kernels/ck_gemm_includes.hpp>
#include <migraphx/kernels/ck_gemm2.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
...@@ -88,111 +54,15 @@ namespace migraphx {
extern "C" {
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
constexpr auto alens = get_shape_c<decltype(a_t)>{}.lens;
constexpr auto m = alens[0];
constexpr auto k = alens[1];
constexpr auto blens = get_shape_c<decltype(b_t)>{}.lens;
constexpr auto n = blens[1];
constexpr auto astrides = get_shape_c<decltype(a_t)>{}.strides;
constexpr auto as = astrides[0];
constexpr auto bstrides = get_shape_c<decltype(b_t)>{}.strides;
constexpr auto bs = bstrides[0];
constexpr auto cstrides = get_shape_c<decltype(c_t)>{}.strides;
constexpr auto cs = cstrides[0];
auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
auto b_grid_desc_k0_n0_n1_k1 =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true;
constexpr ck::index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() /* / sizeof(float) */;
__shared__ void* p_shared_block[shared_block_size];
make_tensors()(p_shared_block)([&](auto p_t) {
ck_gemm(a_t, b_t, c_t, p_t);
});
constexpr ck::index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(float);
__shared__ float p_shared_block[shared_block_size];
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_shared_block,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
// using AGridDesc_K0_M0_M1_K1 =
// decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
// using BGridDesc_K0_N0_N1_K1 =
// decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
// using CGridDesc_M0_M10_M11_N0_N10_N11 =
// decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
// using DefaultBlock2CTileMap =
// decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// const auto kernel = ck::kernel_gemm_dl_v1r3<GridwiseGemm,
// ADataType,
// CDataType,
// remove_reference_t<AGridDesc_K0_M0_M1_K1>,
// remove_reference_t<BGridDesc_K0_N0_N1_K1>,
// remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
// remove_reference_t<DefaultBlock2CTileMap>,
// true,
// true>;
// kernel(a_t.data(),
// b_t.data(),
// c_t.data(),
// a_grid_desc_k0_m0_m1_k1,
// b_grid_desc_k0_n0_n1_k1,
// c_grid_desc_m0_m10_m11_n0_n10_n11,
// block_2_ctile_map);
});
}
...@@ -202,6 +72,7 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__";
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
std::vector<std::string> names() const { return {"ck_gemm"}; }
......
...@@ -31,9 +31,10 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/ck.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include <ck/tensor_operation/gpu/device/device_base.hpp>
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
...@@ -212,6 +213,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, b_desc, c_desc, op);
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <stdio.h>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
// #include "ck/device_utility/device_prop.hpp"
// #include "ck/device_utility/kernel_launch.hpp"
//#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include <ck/ck.hpp>
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
namespace migraphx {
using ABDataType = ck::half_t;
using CDataType = ck::half_t;
using ElementwiseFunctor = ck::half_t;
static constexpr auto I0 = ck::Number<0>{};
// template <typename InDataTypeTuple,
// typename OutDataTypeTuple,
// typename ElementwiseOperation,
// index_t NumDim,
// index_t MPerThread,
// typename InScalarPerVectorSeq,
// typename OutScalarPerVectorSeq>
// struct CKDeviceElementwise
// {
// __device__ constexpr auto GenerateInDataTypePointerTuple()
// {
// return generate_tuple(
// [&](auto I) {
// using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
// return static_cast<const DataType*>(nullptr);
// },
// Number<NumInput>{});
// };
// __device__ constexpr auto GenerateOutDataTypePointerTuple()
// {
// return generate_tuple(
// [&](auto I) {
// using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
// return static_cast<DataType*>(nullptr);
// },
// Number<NumOutput>{});
// };
// template <class Desc_M>
// __device__ constexpr auto PadDescriptor_M_1d(Desc_M desc_m)
// {
// auto gridSize = 72;
// auto blockSize = 1024;
// auto MPerThread = 8;
// const auto M = desc_m.GetLength(I0);
// const ck::index_t loop_step = gridSize * blockSize * MPerThread;
// const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
// const auto desc_m_pad =
// transform_tensor_descriptor(desc_m,
// make_tuple(ck::make_right_pad_transform(M, pad)),
// make_tuple(ck::Sequence<0>{}),
// make_tuple(ck::Sequence<0>{}));
// return desc_m_pad;
// }
// template <class L, class S>
// __device__ constexpr auto MakeDescriptor_M(const L& lengths, const S& strides)
// {
// auto tupleOfShape = generate_tuple(
// [&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<ndim>{});
// auto tupleOfStride = generate_tuple(
// [&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<ndim>{});
// const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// // merge nd to 1d desc - [s0 * s1 * ...]
// if constexpr(ndim > 1)
// {
// const auto desc_m = transform_tensor_descriptor(
// desc,
// make_tuple(make_merge_transform(tupleOfShape)),
// make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
// make_tuple(ck::Sequence<0>{}));
// return PadDescriptor_M_1d(desc_m);
// }
// else
// {
// return PadDescriptor_M_1d(desc);
// }
// }
// template <index_t TupleSize>
// __device__ constexpr auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
// {
// return generate_tuple(
// [&](auto) {
// if constexpr(NumDim > 1)
// {
// return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
// }
// else
// {
// return MakeDescriptor_M({1}, {1}, 1, 1);
// };
// },
// Number<TupleSize>{});
// };
// };
struct Add
{
template <typename Y, typename X0, typename X1>
__device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = x0 + x1;
};
};
struct Mul
{
template <typename Y, typename X0, typename X1>
__device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = x0 * x1;
};
};
struct Div
{
template <typename Y, typename X0, typename X1>
__device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = x0 / x1;
};
};
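// Note (not part of the original source): Add, Mul and Div above follow
// composable_kernel's binary elementwise functor convention, where the output
// is the first, by-reference parameter. Any of them can therefore be passed
// where the CK device/gridwise elementwise classes used below expect an
// ElementwiseOperation, in place of ck::tensor_operation::element_wise::Add.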
using InDataTypeTuple = ck::Tuple<ABDataType, ABDataType>;
using OutDataTypeTuple = ck::Tuple<CDataType>;
using ElementwiseOperation = Add;
static constexpr auto MPerThread = 8;
using InScalarPerVectorSeq = ck::Sequence<1, 8>;
using OutScalarPerVectorSeq = ck::Sequence<8>;
// using DeviceElementwiseAddInstance =
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
// ck::Tuple<CDataType>,
// Add,
// 3,
// 8,
// ck::Sequence<1, 8>,
// ck::Sequence<8>>;
template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{
// auto idx = make_index();
constexpr auto a_lens = get_shape_c<T>{}.lens;
constexpr auto a_strides = get_shape_c<T>{}.strides;
constexpr ck::index_t ndim = a_lens.size();
constexpr auto b_lens = get_shape_c<U>{}.lens;
constexpr auto b_strides = get_shape_c<U>{}.strides;
constexpr ck::index_t b_ndim = b_lens.size();
constexpr auto c_lens = get_shape_c<V>{}.lens;
constexpr auto c_strides = get_shape_c<V>{}.strides;
constexpr ck::index_t c_ndim = c_lens.size();
assert(b_ndim == ndim and c_ndim == ndim);
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>,
Add,
ndim,
8,
ck::Sequence<1, 8>,
ck::Sequence<8>>;
using shapes_t = std::array<ck::index_t, 3>;
//shapes_t lengths_abc;
//copy(c_lens.begin(), c_lens.end(), lengths_abc);
shapes_t lengths_abc = {c_lens[0], c_lens[1], c_lens[2]};
//constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
constexpr auto strides_a = static_cast<shapes_t>(a_strides);
constexpr auto strides_b = static_cast<shapes_t>(b_strides);
constexpr auto strides_c = static_cast<shapes_t>(c_strides);
std::array<const void*, 2> input = {a_t.data(),
b_t.data()};
std::array<void*, 1> output = {c_t.data()};
auto ck_add = DeviceElementwiseAddInstance{};
auto argument = ck_add.MakeArgumentPointer(
lengths_abc, {strides_a, strides_b}, {strides_c}, input, output, Add{});
using InGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{}));
using OutGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{}));
using InDataTypePointerTuple = decltype(ck_add.GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(ck_add.GenerateOutDataTypePointerTuple());
using GridwiseElementwise = ck::GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
GridwiseElementwise::Run(argument.in_grid_1d_desc_tuple_,
argument.out_grid_1d_desc_tuple_,
argument.in_dev_buffers_,
argument.out_dev_buffers_,
argument.elementwise_op_);
}
} // namespace migraphx
#endif
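For reference, a minimal sketch of the JIT kernel string that is assumed to wrap ck_elementwise from this header (the ck_elementwise.cpp hunk above is truncated, so the exact wrapper body is not shown in this diff; the pattern mirrors the ck_gemm_kernel wrapper earlier in this commit, and the header list here is illustrative):

#include <migraphx/kernels/ck_elementwise2.hpp>
#include <args.hpp>

namespace migraphx {
extern "C" {
// Raw pointer arguments are rebound to typed tensor_views by make_tensors()
// and forwarded to the device-side ck_elementwise defined in this header.
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
    make_tensors()(a_p, b_p, c_p)(
        [](auto a_t, auto b_t, auto c_t) { ck_elementwise(a_t, b_t, c_t); });
}
}
} // namespace migraphx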
...@@ -29,193 +29,12 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/utility/common_header.hpp"
#include <migraphx/kernels/ck_includes.hpp>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace migraphx {
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Values hard-coded by CK
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t BlockSize = 256;
static constexpr ck::index_t K0PerBlock = 16;
static constexpr ck::index_t M1PerThread = 4;
static constexpr ck::index_t N1PerThread = 4;
static constexpr ck::index_t KPerThread = 1;
using M1N1ThreadClusterM1Xs = S<8, 2>;
using M1N1ThreadClusterN1Xs = S<8, 2>;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = S<2, 1, 4, 1>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = S<8, 1, 32, 1>;
using ABlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>;
using ABlockTransferSrcAccessOrder = S<0, 3, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = S<2, 1, 4, 1>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = S<8, 1, 32, 1>;
using BBlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>;
using BBlockTransferSrcAccessOrder = S<0, 3, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{
assert(K % K1 == 0);
const ck::index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(StrideA, I1));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_right_pad_transform(M, PadM)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
}
}
static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ck::index_t StrideB)
{
assert(K % K1 == 0);
const ck::index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(StrideB, I1));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
}
}
static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(StrideC, I1));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(c_grid_desc_m_n,
ck::make_tuple(ck::make_right_pad_transform(M, PadM),
ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
ck::make_tuple(ck::make_pass_through_transform(M), ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// template <class T, class U, class V, class W>
// __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
template <class T, class U, class V>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, float* p_t)
template <class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
constexpr auto alens = get_shape_c<T>{}.lens;
constexpr auto m = alens[0];
...@@ -238,63 +57,106 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, float* p_t)
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
auto b_grid_desc_k0_n0_n1_k1 =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true;
auto num_bytes = GridwiseGemm::GetSharedMemoryNumberOfByte();
printf("Bytes: %i\n", int(num_bytes));
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
/* p_t.data(), */ p_t,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
if(idx.global == 0)
{
printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m_k1.GetLength(I0)), int(a_grid_desc_k0_m_k1.GetLength(I1)), int(a_grid_desc_k0_m_k1.GetLength(I2)));
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n_k1.GetLength(I0)), int(b_grid_desc_k0_n_k1.GetLength(I1)), int(b_grid_desc_k0_n_k1.GetLength(I2)));
printf("c_grid_desc_m_n{%i, %i}\n", int(c_grid_desc_m_n.GetLength(I0)), int(c_grid_desc_m_n.GetLength(I1)));
}
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11;
DefaultBlock2CTileMap block_2_ctile_map;
if(true or GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
{
//printf("Is valid\n");
a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
b_grid_desc_k0_n0_n1_k1 =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
c_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
}
else
{
//printf("Not valid\n");
}
if(idx.global == 0)
{
printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
}
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = false;
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
constexpr bool HasMainKBlockLoop = false;
constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
else
{
constexpr bool HasMainKBlockLoop = false;
constexpr bool HasDoubleTailKBlockLoop = false;
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
}
} // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck_gemm_includes.hpp>
namespace migraphx {
template <class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
static gemm tp{};
using GridwiseGemm = decltype(tp.gg);
constexpr auto alens = get_shape_c<T>{}.lens;
constexpr auto m = alens[0];
constexpr auto k = alens[1];
constexpr auto blens = get_shape_c<U>{}.lens;
constexpr auto n = blens[1];
constexpr auto astrides = get_shape_c<T>{}.strides;
constexpr auto as = astrides[0];
constexpr auto bstrides = get_shape_c<U>{}.strides;
constexpr auto bs = bstrides[0];
constexpr auto cstrides = get_shape_c<V>{}.strides;
constexpr auto cs = cstrides[0];
auto idx = make_index();
if(idx.global == 0)
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
constexpr auto a_grid_desc_ak0_m_ak1 = tp.MakeAGridDescriptor_AK0_M_AK1(static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
constexpr auto b_grid_desc_bk0_n_bk1 = tp.MakeBGridDescriptor_BK0_N_BK1(static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
constexpr auto c_grid_desc_m_n = tp.MakeCGridDescriptor_M_N(static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
/* constexpr */ auto block_2_ctile_map = tp.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
if(idx.global == 0)
{
printf("a_grid_desc_ak0_m_ak1{%i, %i, %i}\n", int(a_grid_desc_ak0_m_ak1.GetLength(I0)), int(a_grid_desc_ak0_m_ak1.GetLength(I1)), int(a_grid_desc_ak0_m_ak1.GetLength(I2)));
printf("b_grid_desc_bk0_n_bk1{%i, %i, %i}\n", int(b_grid_desc_bk0_n_bk1.GetLength(I0)), int(b_grid_desc_bk0_n_bk1.GetLength(I1)), int(b_grid_desc_bk0_n_bk1.GetLength(I2)));
printf("c_grid_desc_m_n{%i, %i}\n", int(c_grid_desc_m_n.GetLength(I0)), int(c_grid_desc_m_n.GetLength(I1)));
}
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock{};
if(true or GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
}
// if(idx.global == 0)
// {
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
// printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// }
const auto K =
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
auto a_element_op = tp.a_element_op;
auto b_element_op = tp.b_element_op;
auto c_element_op = tp.c_element_op;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
constexpr bool HasMainKBlockLoop = true;
GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
else
{
constexpr bool HasMainKBlockLoop = false;
GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
}
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
//#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
//#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
namespace migraphx {
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// using ALayout = Row;
// using BLayout = Row;
// using CLayout = Row;
// using ADataType = ck::half_t;
// using BDataType = ck::half_t;
// using CDataType = ck::half_t;
// using GemmAccDataType = float;
// using CShuffleDataType = ck::half_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// using AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
// using BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
// using CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Values hard-coded by CK
// static constexpr ck::index_t NumGemmKPrefetchStage = 1;
// static constexpr ck::index_t BlockSize = 256;
// static constexpr ck::index_t MPerBlock = 256;
// static constexpr ck::index_t NPerBlock = 128;
// static constexpr ck::index_t KPerBlock = 32;
// static constexpr ck::index_t AK1 = 8;
// static constexpr ck::index_t BK1 = 2;
// static constexpr ck::index_t MPerXDL = 32;
// static constexpr ck::index_t NPerXDL = 32;
// static constexpr ck::index_t MXdlPerWave = 4;
// static constexpr ck::index_t NXdlPerWave = 2;
// using ABlockTransferThreadClusterLengths_AK0_M_AK1 = S<4, 64, 1>;
// using ABlockTransferThreadClusterArrangeOrder = S<1, 0, 2>;
// using ABlockTransferSrcAccessOrder = S<1, 0, 2>;
// static constexpr ck::index_t ABlockTransferSrcVectorDim = 2;
// static constexpr ck::index_t ABlockTransferSrcScalarPerVector = 8;
// static constexpr ck::index_t ABlockTransferDstScalarPerVector_AK1 = 8;
// static constexpr ck::index_t ABlockLdsExtraM = 1;
// using BBlockTransferThreadClusterLengths_BK0_N_BK1 = S<8, 32, 1>;
// using BBlockTransferThreadClusterArrangeOrder = S<0, 2, 1>;
// using BBlockTransferSrcAccessOrder = S<0, 2, 1>;
// static constexpr ck::index_t BBlockTransferSrcVectorDim = 1;
// static constexpr ck::index_t BBlockTransferSrcScalarPerVector = 4;
// static constexpr ck::index_t BBlockTransferDstScalarPerVector_BK1 = 2;
// static constexpr ck::index_t BBlockLdsExtraN = 0;
// static constexpr ck::index_t CShuffleMXdlPerWavePerShuffle = 1;
// static constexpr ck::index_t CShuffleNXdlPerWavePerShuffle = 1;
// using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = S<1, 32, 1, 8>;
// static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ __device__ constexpr ck::index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const ck::index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
ck::index_t idx_N0 = block_1d_id % N0;
ck::index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
ck::index_t idx_M00 = idx_M0 / M01_;
ck::index_t idx_M01 = idx_M0 % M01_;
ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
ck::index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct TuningParams
{
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
ck::make_tuple(StrideA, I1));
}
else if constexpr(ck::is_same_v<ck::tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
ck::make_tuple(I1, StrideA));
}
}();
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both M and K
//assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad M, but not K
//assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_right_pad_transform(MRaw, MPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad K, but not M
//assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw), ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
//assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
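// Worked example of the padding arithmetic above (illustrative numbers, not
// taken from this commit): with MRaw = 1000, KRaw = 504, MPerBlock = 256,
// KPerBlock = 32 and AK1 = 8,
//   M = integer_divide_ceil(1000, 256) * 256 = 1024, MPad = 24
//   K = integer_divide_ceil(504, 32)   * 32  = 512,  KPad = 8
// so under MKPadding/MNKPadding AK0 = K / AK1 = 64 and A is viewed as an
// (AK0, M, AK1) = (64, 1024, 8) descriptor, while with GemmSpecialization::Default
// AK0 = KRaw / AK1 = 63 and the view is (63, 1000, 8).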
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
ck::make_tuple(I1, StrideB));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
ck::make_tuple(StrideB, I1));
}
}();
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad both N and K
//assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{
// pad N, but not K
//assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad K, but not N
//assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(NRaw), ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
//assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static constexpr auto MakeCGridDescriptor_M_N(ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
ck::make_tuple(StrideC, I1));
}
else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
ck::make_tuple(I1, StrideC));
}
}();
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad), ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw), ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
GridwiseGemm gg{};
AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{};
CElementwiseOperation c_element_op{};
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
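    // Illustrative sizing only (example values, not taken from this commit): for a
    // 3840 x 4096 C matrix with MPerBlock = 64 and NPerBlock = 128, this map
    // distributes (3840 / 64) * (4096 / 128) = 60 * 32 = 1920 workgroups over the
    // (m0, n0) tile grid of C.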
};
using gemm = TuningParams
// clang-format off
//| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// FP32:
// < Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
static gemm htp{};
using hGridwiseGemm = decltype(htp.gg);
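// Per the column headers above, the one uncommented tuning row selects a 128-thread
// workgroup computing a 64 x 128 tile of C with KPerBlock = 32, AK1 = 8, BK1 = 2,
// 32 x 32 XDL instructions, and 2 x 2 XDL tiles per wave. The commented rows are
// alternative candidate configurations kept around for tuning.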
} // namespace migraphx
#endif
...@@ -36,8 +36,7 @@ ...@@ -36,8 +36,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace migraphx { namespace migraphx {
...@@ -53,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{}; ...@@ -53,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Col; using ALayout = Row;//Col;
using BLayout = Row; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
...@@ -62,6 +61,10 @@ using BDataType = float; ...@@ -62,6 +61,10 @@ using BDataType = float;
using CDataType = float; using CDataType = float;
using AccDataType = float; using AccDataType = float;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -212,5 +215,60 @@ using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); ...@@ -212,5 +215,60 @@ using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
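// GridwiseGemmDl is CK's non-XDL ("DL") gridwise GEMM, presumably kept here as a
// fallback path for targets without MFMA/XDL instructions; it consumes the same
// dummy-extent descriptor types derived above.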
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
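// These aliases derive the blocked descriptor types and the default block-to-C-tile
// map that the DL gridwise kernel expects at launch, again from the dummy-extent
// descriptors above.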
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -27,16 +27,43 @@ ...@@ -27,16 +27,43 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
// struct test_ck_gemm : verify_program<test_ck_gemm>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::float_type, {3840, 4096}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// return p;
// }
// };
struct test_ck_gemm : verify_program<test_ck_gemm> struct test_ck_gemm : verify_program<test_ck_gemm>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::half_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3}}; migraphx::shape m2_shape{migraphx::shape::half_type, {3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); std::vector<float> v1(2*3, 1);
auto l2 = mm->add_parameter("2", m2_shape); std::iota(v1.begin(), v1.end(), 1);
std::vector<float> v2(3*4, 1);
//std::iota(v2.begin(), v2.end(), 1);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
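        // Sanity-check values, assuming the literals are packed row-major:
        // A = [[1, 2, 3], [4, 5, 6]] in half precision and B is all ones, so the
        // expected 2 x 4 result is [[6, 6, 6, 6], [15, 15, 15, 15]].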
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
//l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2); mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
......