Commit f83139de authored by Alan Turner's avatar Alan Turner
Browse files

add ck_gemm_add_add_gelu fusion

parent 21b14ff2
......@@ -28,4 +28,4 @@ half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@639147432b6922bd8e4051ba751e4e63dd4eb196 -X header
ROCmSoftwarePlatform/composable_kernel -X header
......@@ -91,6 +91,7 @@ add_library(migraphx_gpu
deconvolution.cpp
device_name.cpp
elu.cpp
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
......
......@@ -138,12 +138,12 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t groups = (n + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
printf("n: %zu\n", n);
printf("over: %zu\n", over);
printf("max_global: %zu\n", max_global);
printf("groups: %zu\n", groups);
printf("max_blocks: %zu\n", max_blocks);
printf("nglobal: %zu\n", nglobal);
// printf("n: %zu\n", n);
// printf("over: %zu\n", over);
// printf("max_global: %zu\n", max_global);
// printf("groups: %zu\n", groups);
// printf("max_blocks: %zu\n", max_blocks);
// printf("nglobal: %zu\n", nglobal);
return std::min(nglobal, n);
};
}
......
......@@ -95,8 +95,6 @@ struct block_settings
int npb;
};
namespace fs = std::filesystem;
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
const std::vector<std::string> instances{
......@@ -135,6 +133,38 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{128, 64, 128},
{128, 64, 128}};
// Lookup of "M N K" problem-size string -> index into the tuning instance table.
// NOTE(review): the original initializer repeated several keys; duplicate keys in
// an unordered_map initializer-list are silently ignored (the first occurrence
// wins), so the repeats — including a conflicting {"160 2048 64", 7} listed after
// {"160 2048 64", 6} — had no effect at runtime. The dead entries are removed
// here, keeping the first-occurrence value for each key so behavior is unchanged.
const std::unordered_map<std::string, int> tuning_lookup{
    {"1024 2048 1416", 9},
    {"1416 2048 512", 9},
    {"4096 2048 1416", 5},
    {"512 2048 2048", 6},
    {"2048 2048 512", 9},
    {"2048 2048 2048", 9},
    {"160 2048 64", 6},
    {"160 2048 2048", 6},
    {"288 2048 768", 7},
    {"1024 2048 5120", 9},
    {"512 2048 512", 9},
    {"39488 2048 512", 5},
    {"5120 2048 512", 5},
    {"512 2048 6536", 9},
    {"32 384 160", 6},
    {"32 856 160", 6},
    {"32 8 160", 6},
    {"512 2048 39488", 9},
    {"8192 2048 3200", 5},
    {"4096 2048 4096", 5},
    {"8 2048 9224", 6}};
std::vector<std::string> names() const { return {"ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
......@@ -147,12 +177,25 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
hip_compile_options options;
auto out_s = inputs.back();
auto m = out_s.lens().front();
auto n = out_s.lens().back();
auto k = inputs.front().lens().back();
std::string mnk = to_string(m) + " " + to_string(n) + " " + to_string(k);
auto itr = tuning_lookup.find(mnk);
std::cout << mnk << std::endl;
if (itr != tuning_lookup.end())
{
i = tuning_lookup.at(mnk);
std::cout << " i: " << i << std::endl;
}
else
std::cout << "i: " << i << std::endl;
auto b_s = params[i];
auto block_size = b_s.bs;
auto m_per_block = b_s.mpb;
auto n_per_block = b_s.npb;
auto m = out_s.lens().front();
auto n = out_s.lens().back();
auto grid_size = get_grid_size(m, m_per_block, n, n_per_block);
options.set_launch_params(v, grid_size * block_size, block_size);
......@@ -161,7 +204,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = "ck_gemm_kernel";
options.virtual_inputs = inputs;
auto k = inputs.front().lens().back();
auto sa = inputs.front().strides().front();
auto sb = inputs.at(1).strides().front();
auto sc = inputs.back().strides().front();
......
......@@ -46,7 +46,6 @@ namespace gpu {
static const char* const ck_gemm_aag_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_fusion_inclusion.hpp>
#include <migraphx/kernels/ck_gemm_add_add_gelu.hpp>
#include <migraphx/kernels/gemm_aag_instance.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
......@@ -56,12 +55,20 @@ static const char* const ck_gemm_aag_kernel = R"__migraphx__(
namespace migraphx {
using gemm_t = ${instance}, ${m}, ${k}, ${n}, ${sa}, ${sb}, ${sd0}, ${sd1}, ${se}>;
constexpr __device__ gemm_t ckdg{};
using GridwiseGemm = decltype(ckdg.gridwisegemm);
extern "C" {
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* d_p, void* e_p)
__global__ void ck_gemm_aag_kernel(void* a_p, void* b_p, void* d0_p, void* d1_p, void* e_p)
{
make_tensors()(a_p, b_p, d_p, e_p)([](auto a_t, auto b_t, auto d_t, auto e_t) {
ck_gemm_add_add_gelu(a_t, b_t, c_t, p_t);
make_tensors()(a_p, b_p, d0_p, d1_p, e_p)([&](auto a_t, auto b_t, auto d0_t, auto d1_t, auto e_t) {
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
make_tensors()(p_shared)([&](auto p_t){
ck_gemm_add_add_gelu<gemm_t>(a_t, b_t, d0_t, d1_t, e_t, p_t);
});
});
}
......@@ -71,88 +78,16 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* d_p, void* e_p)
)__migraphx__";
// std::string kernel_p1 = R"__migraphx__(
// #include <migraphx/kernels/ck_gemm_includes.hpp>
// #include <migraphx/kernels/ck_gemm2.hpp>
// #include <migraphx/kernels/ops.hpp>
// #include <migraphx/kernels/integral_constant.hpp>
// #include <migraphx/kernels/generic_constant.hpp>
// #include <args.hpp>
// #include <hip/hip_runtime_api.h>
// namespace migraphx {
// using gemm = CKDeviceGemm)__migraphx__";
// std::string tuning_vals = R"__migraphx__(< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>;)__migraphx__";
// std::string kernel_p2 = R"__migraphx__(
// extern "C" {
// __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
// {
// constexpr gemm htp{};
// using hGridwiseGemm = decltype(htp.gg);
// make_tensors()(a_p, b_p, c_p)([&](auto a_t, auto b_t, auto c_t) {
// constexpr ck::index_t shared_block_size =
// hGridwiseGemm::GetSharedMemoryNumberOfByte();
// __shared__ char p_shared_block[shared_block_size];
// make_tensors()(p_shared_block)([&](auto p_t) {
// ck_gemm(a_t, b_t, c_t, p_t, htp);
// });
// });
// }
// }
// } // namespace migraphx
// )__migraphx__";
// std::string kernel_string = kernel_p1 + tuning_vals + kernel_p2;
static std::string gemm_aag_instance = R"__migraphx__(
#ifndef MIGRAPHX_GUARD_GEMM_AAG_INSTANCE_HPP
#define MIGRAPHX_GUARD_GEMM_AAG_INSTANCE_HPP
#include <migraphx/kernels/ck_fusion_inclusion.hpp>
namespace migraphx {
using ck_op =
CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>;
// CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
} // namespace migraphx
#endif
)__migraphx__";
// static const char* const ck_gemm_kernel = kernel_string.c_str();
namespace fs = std::filesystem;
std::size_t int_divide_ceil(std::size_t x, std::size_t y)
// Ceiling division: smallest q such that q * y >= x (requires y > 0).
// Computed as x / y + (x % y != 0) instead of (x + y - 1) / y so the
// intermediate sum cannot wrap around std::size_t for very large x.
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    return x / y + static_cast<std::size_t>(x % y != 0);
}
// Number of thread blocks needed to tile an m x n output with tiles of
// mpb x npb elements: ceil(m / mpb) * ceil(n / npb).
static std::size_t get_grid_size(std::size_t m, std::size_t mpb, std::size_t n, std::size_t npb)
{
    const std::size_t m_tiles = (m + mpb - 1) / mpb;
    const std::size_t n_tiles = (n + npb - 1) / npb;
    return m_tiles * n_tiles;
}
struct block_settings
{
int bs;
......@@ -160,56 +95,103 @@ struct block_settings
int npb;
};
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
struct ck_gemm_add_add_gelu_compiler : compiler<ck_gemm_add_add_gelu_compiler>
{
std::vector<std::string> names() const { return {"ck_gemm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
const std::vector<std::string> instances{
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8",
"CK_DeviceGemmMultipleD< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8"};
const std::vector<block_settings> params {
{256, 256, 128},
{256, 256, 128},
{256, 128, 256},
{256, 128, 256},
{256, 128, 128},
{256, 128, 128},
{256, 128, 64},
{256, 128, 64},
{256, 64, 128},
{256, 64, 128},
{128, 128, 128},
{128, 128, 128},
{128, 128, 64},
{128, 128, 64},
{128, 64, 128},
{128, 64, 128}};
std::vector<std::string> names() const { return {"ck_gemm_add_add_gelu", "gpu::ck_gemm_add_add_gelu"}; }
// Compiles the fused GEMM+add+add+gelu kernel: writes the CK instance header,
// picks a tuning instance, computes launch geometry from the output shape and
// the instance's tile sizes, and interpolates the problem sizes/strides into
// the kernel source before JIT compilation.
// NOTE(review): the diff residue contained duplicate declarations of m/n/
// grid_size and a doubled set_launch_params/return; resolved here to the
// post-change version. Debug printf/cout output removed, matching the
// commented-out prints elsewhere in this commit.
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
    // Write the GEMM instance header next to the other kernel headers so the
    // JIT-compiled kernel source can include it.
    std::string path = fs::absolute(__FILE__);
    path             = path.substr(0, path.find_last_of("\\/"));
    path             = path.substr(0, path.find_last_of("\\/"));
    path += "/kernels/include/migraphx/kernels/gemm_aag_instance.hpp";
    std::ofstream out(path);
    out << gemm_aag_instance;
    out.close();

    // Tuning instance selection; 4 is the default unless overridden.
    int i = 4;
    if(contains(v, "tuning_val"))
        i = v.at("tuning_val").to<int>();
    assert(i >= 0 and static_cast<std::size_t>(i) < instances.size());

    hip_compile_options options;
    auto out_s = inputs.back();
    auto m     = out_s.lens().front();
    auto n     = out_s.lens().back();
    auto k     = inputs.front().lens().back();

    // Launch geometry for the chosen instance's block/tile sizes.
    auto b_s         = params[i];
    auto block_size  = b_s.bs;
    auto m_per_block = b_s.mpb;
    auto n_per_block = b_s.npb;
    auto grid_size   = get_grid_size(m, m_per_block, n, n_per_block);
    options.set_launch_params(v, grid_size * block_size, block_size);
    options.inputs         = inputs;
    options.output         = out_s;
    options.kernel_name    = "ck_gemm_aag_kernel";
    options.virtual_inputs = inputs;

    // Leading-dimension strides for A, B, the two D inputs and the output E.
    auto sa  = inputs.front().strides().front();
    auto sb  = inputs.at(1).strides().front();
    auto sd0 = inputs.at(2).strides().front();
    auto sd1 = inputs.at(3).strides().front();
    auto se  = inputs.at(4).strides().front();

    auto src = interpolate_string(ck_gemm_aag_kernel,
                                  {{"instance", instances[i]},
                                   {"m", to_string(m)},
                                   {"k", to_string(k)},
                                   {"n", to_string(n)},
                                   {"sa", to_string(sa)},
                                   {"sb", to_string(sb)},
                                   {"sd0", to_string(sd0)},
                                   {"sd1", to_string(sd1)},
                                   {"se", to_string(se)}});
    return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
......
......@@ -39,13 +39,22 @@
// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// #include "ck/ck.hpp"
// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
// #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
// #include "ck/utility/common_header.hpp"
// #include "ck/tensor_description/tensor_descriptor.hpp"
// #include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
// #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -55,6 +64,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace migraphx {
......@@ -89,6 +100,71 @@ using S = ck::Sequence<Is...>;
using namespace ck; //
//////////////////////////
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
// Maps a 1-D block id onto a 2-D (m0, n0) C-tile coordinate over an
// M0 x N0 tile grid. Block ids walk groups of up to M01 consecutive M-tiles;
// near the tail, where fewer than M01 M-tiles remain, the group size is
// shrunk (see M01_adapt below) so every block id still lands on a valid tile.
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
// Compile-time dimension indices used with the grid descriptor.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
// c_grid_desc_m_n: descriptor of the full M x N output.
// M01: number of consecutive M-tiles grouped per block-id column (default 8).
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
// Total blocks required: ceil(M / MPerBlock) * ceil(N / NPerBlock).
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size;
}
// Converts the 1-D block index (idx_top[I0]) to the (m0, n0) tile index.
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
// Row-major decomposition of the block id over the M0 x N0 tile grid.
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
// Shrink the M-group at the tail when M0 is not a multiple of M01_.
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
// Re-linearize within the group, then split by the (possibly shrunk) group size.
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
// Every index produced above is in range by construction, provided the grid
// size came from CalculateGridSize(); no per-tile bounds check is needed.
template <typename CTileIdx, typename CTileDim>
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
......@@ -102,7 +178,7 @@ template <typename ALayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -132,13 +208,26 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::index_t MRaw,
ck::index_t KRaw,
ck::index_t NRaw,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideD0,
ck::index_t StrideD1,
ck::index_t StrideE,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct CK_DeviceGemmMultipleD
{
static constexpr std::array<index_t, 2> MRaws{};
static constexpr std::array<index_t, 2> NRaws{};
static constexpr std::array<index_t, 2> DsStride{StrideD0, StrideD1};
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
constexpr static auto MakeAGridDescriptor_M_K()
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -156,7 +245,7 @@ struct CK_DeviceGemmMultipleD
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
constexpr static auto MakeBGridDescriptor_N_K()
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -175,7 +264,7 @@ struct CK_DeviceGemmMultipleD
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
constexpr static auto MakeEGridDescriptor_M_N()
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
......@@ -193,24 +282,92 @@ struct CK_DeviceGemmMultipleD
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
// Builds the padded M x N grid descriptor for the first D tensor, using the
// compile-time StrideD0 as its leading-dimension stride.
// NOTE(review): near-duplicate of MakeEGridDescriptor_M_N1 below — the two
// differ only in which stride constant they bake in.
template <typename ELay>
constexpr static auto MakeEGridDescriptor_M_N0()
{
// Raw (unpadded) descriptor; the stride sits on the row or column dimension
// depending on layout. No final else: ELay must be Row- or ColumnMajor.
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideD0, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideD0));
}
}();
// Pad M/N up to multiples of the block tile sizes.
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
// Builds the padded M x N grid descriptor for the second D tensor, using the
// compile-time StrideD1 as its leading-dimension stride.
// NOTE(review): near-duplicate of MakeEGridDescriptor_M_N0 above — only the
// stride constant differs.
template <typename ELay>
constexpr static auto MakeEGridDescriptor_M_N1()
{
// Raw (unpadded) descriptor; the stride sits on the row or column dimension
// depending on layout. No final else: ELay must be Row- or ColumnMajor.
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideD1, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideD1));
}
}();
// Pad M/N up to multiples of the block tile sizes.
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
// Builds the tuple of padded M x N grid descriptors for the D tensors, one
// per entry of DsLayout, via the zero-argument MakeEGridDescriptor_M_N.
// NOTE(review): the diff residue left two consecutive return statements in the
// lambda (the removed DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i],
// DsStride[i]) call followed by the new zero-arg call); only the new one is kept.
constexpr static auto MakeDsGridDescriptor_M_N()
{
    return generate_tuple(
        [&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
            return MakeEGridDescriptor_M_N<DLayout>();
        },
        Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K());
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K());
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N())>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>());
// Fills the typed D-pointer tuple p_ds_grid_ from the untyped pointer array
// p_ds_grid: each entry is cast to its DDataType from DsDataType. The
// commented-out code previously also rebuilt the D descriptors here —
// presumably superseded by MakeDsDescTuple; confirm before re-enabling.
template<class T, class U/* , class V */>
constexpr static auto Populate_D_Ptr(T& p_ds_grid_, U& p_ds_grid/* , V& ds_grid_desc_m_n_ */)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
//using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
// if constexpr(i == 0)
// {
// ds_grid_desc_m_n_.At(i) =
// MakeEGridDescriptor_M_N0<DLayout>();
// }
// else
// {
// ds_grid_desc_m_n_.At(i) =
// MakeEGridDescriptor_M_N1<DLayout>();
// }
});
// return make_tuple(MakeEGridDescriptor_M_N0<remove_cvref_t<tuple_element_t<0, DsLayout>>>(),
// MakeEGridDescriptor_M_N1<remove_cvref_t<tuple_element_t<1, DsLayout>>>());
}
// Builds the descriptor tuple for the two D tensors, pairing each layout with
// its own stride (StrideD0 for the first, StrideD1 for the second).
// NOTE(review): hard-codes exactly two D tensors; this would need updating if
// NumDTensor were ever != 2.
constexpr static auto MakeDsDescTuple()
{
return make_tuple(MakeEGridDescriptor_M_N0<remove_cvref_t<tuple_element_t<0, DsLayout>>>(),
MakeEGridDescriptor_M_N1<remove_cvref_t<tuple_element_t<1, DsLayout>>>());
}
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
......@@ -266,12 +423,26 @@ struct CK_DeviceGemmMultipleD
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
// return block_id to E matrix tile idx (m0, n0) mapping
// Builds the default block-id -> E-tile map for this gemm using the locally
// defined BlockToCTileMap_M00_N0_M01Adapt (its M01 defaults to 8).
__device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n_);
}
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
GridwiseGemm gridwisegemm{};
}
static constexpr DsGridDesc_M_N ds_grid_desc_m_n{};
static constexpr EGridDesc_M_N e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>();
static constexpr Block2ETileMap block_2_etile_map = MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
static constexpr GridwiseGemm gridwisegemm{};
static constexpr AElementwiseOperation a_element_op{};
static constexpr BElementwiseOperation b_element_op{};
static constexpr CDEElementwiseOperation cde_element_op{};
};
} // namespace migraphx
......
......@@ -34,78 +34,133 @@
namespace migraphx {
// Debugging helper: an empty __device__ function whose parameter list mirrors
// the argument list of GridwiseGemm::Run. Substituting a call to fake_op for
// the real Run call forces all kernel arguments to be materialized without
// executing the GEMM, which isolates compile/launch problems from the kernel
// body itself.
// FIX: the previous 11-parameter signature was left behind by the diff with no
// body and no semicolon (ill-formed C++); only the 13-parameter version used by
// the ck_gemm_add_add_gelu fusion is kept.
template <class A, class B, class C, class D, class E, class F, class G,
          class H, class I, class J, class K, class L, class M>
__device__ void fake_op(A, B, C, D, E, F, G, H, I, J, K, L, M)
{
}
// Fused GEMM + add + add + GELU entry point, executed on-device via the
// Composable Kernel GridwiseGemmMultipleD pipeline.
//
// Template parameters:
//   G          - CK_DeviceGemmMultipleD configuration type; carries the
//                compile-time grid descriptors, block-to-tile map and the
//                element-wise operator objects as constexpr members.
//   T, U       - tensor-view types of the GEMM inputs A and B.
//   V          - tensor-view type shared by both D (bias) inputs.
//   W, X       - tensor-view types of the E input descriptor tensor and the
//                output buffer.
//
// Arguments:
//   a_t, b_t     - GEMM operands.
//   d0_t, d1_t   - the two auxiliary D tensors combined with the GEMM result
//                  by the CDE element-wise op (presumably the two adds feeding
//                  GELU — confirm against the fusion pass).
//   e_t, p_t     - E tensor and the destination pointer tensor passed to
//                  GridwiseGemm::Run.
//
// FIX: the diff residue of the replaced 4-argument overload (an opening brace
// whose closing brace survived only inside a comment) has been removed, along
// with the large blocks of dead commented-out code; the active logic is
// unchanged.
template <class G, class T, class U, class V, class W, class X>
__device__ void ck_gemm_add_add_gelu(const T& a_t, const U& b_t, const V& d0_t, const V& d1_t, const W& e_t, X& p_t)
{
    // All descriptor/config state lives in a constexpr instance of the
    // device-gemm configuration type.
    constexpr static G ckdg{};
    using GridwiseGemm = decltype(ckdg.gridwisegemm);
    // Tensor descriptors for the problem definition.
    constexpr auto a_grid_desc_m_k = ckdg.MakeAGridDescriptor_M_K();
    constexpr auto b_grid_desc_n_k = ckdg.MakeBGridDescriptor_N_K();
    constexpr auto e_grid_desc_m_n = ckdg.e_grid_desc_m_n;
    // Tensor descriptors for block/thread-wise copy.
    constexpr auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
    constexpr auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
    // Block-id to E-matrix tile (m0, n0) mapping.
    constexpr auto block_2_etile_map = ckdg.block_2_etile_map;
    // Element-wise operators applied to A, B and the C/D/E combination.
    constexpr auto a_element_op   = ckdg.a_element_op;
    constexpr auto b_element_op   = ckdg.b_element_op;
    constexpr auto cde_element_op = ckdg.cde_element_op;
    // Gather the two D (bias) tensor pointers into CK's typed pointer tuple.
    constexpr std::size_t NumDTensor = 2;
    std::array<const void*, NumDTensor> p_ds_grid{d0_t.data(), d1_t.data()};
    typename GridwiseGemm::DsGridPointer p_ds_grid_{};
    ckdg.Populate_D_Ptr(p_ds_grid_, p_ds_grid);
    constexpr auto ds_grid_desc_m_n = ckdg.MakeDsDescTuple();
    // Reject invalid problem/tile configurations at compile time.
    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                              b_grid_desc_n_k,
                                              ds_grid_desc_m_n,
                                              e_grid_desc_m_n,
                                              block_2_etile_map));
    // Block-tiled descriptors for the Ds and E tensors.
    constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            ds_grid_desc_m_n);
    constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            e_grid_desc_m_n);
    // Total K extent decides whether the main K-block loop runs; the flag must
    // be a compile-time template argument, hence the duplicated Run calls.
    constexpr auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
    if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
    {
        constexpr bool HasMainKBlockLoop = true;
        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
                                                      b_t.data(),
                                                      p_ds_grid_,
                                                      e_t.data(),
                                                      p_t.data(),
                                                      a_element_op,
                                                      b_element_op,
                                                      cde_element_op,
                                                      a_grid_desc_ak0_m_ak1,
                                                      b_grid_desc_bk0_n_bk1,
                                                      ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                      e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                      block_2_etile_map);
    }
    else
    {
        constexpr bool HasMainKBlockLoop = false;
        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
                                                      b_t.data(),
                                                      p_ds_grid_,
                                                      e_t.data(),
                                                      p_t.data(),
                                                      a_element_op,
                                                      b_element_op,
                                                      cde_element_op,
                                                      a_grid_desc_ak0_m_ak1,
                                                      b_grid_desc_bk0_n_bk1,
                                                      ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                      e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                      block_2_etile_map);
    }
}
} // namespace migraphx
......
......@@ -272,17 +272,17 @@ struct miopen_apply
assert(refs.size() == 2);
auto a_lens = refs.front()->get_shape().lens();
auto b_lens = refs.back()->get_shape().lens();
if (refs.front()->get_shape().lens().size() == 2 and
not refs.front()->get_shape().transposed() and
not refs.back()->get_shape().transposed() and
a_lens[0] % 8 == 0 and a_lens[1] % 8 == 0 and
b_lens[0] % 8 == 0 and b_lens[1] % 8 == 0)
{
auto it = mod->replace_instruction(
ins, make_op("ck_gemm"), refs);
return insert_precompile_op(it);
}
else
// if (refs.front()->get_shape().lens().size() == 2 and
// not refs.front()->get_shape().transposed() and
// not refs.back()->get_shape().transposed() and
// a_lens[0] % 8 == 0 and a_lens[1] % 8 == 0 and
// b_lens[0] % 8 == 0 and b_lens[1] % 8 == 0)
// {
// auto it = mod->replace_instruction(
// ins, make_op("ck_gemm"), refs);
// return insert_precompile_op(it);
// }
// else
{
auto output = insert_allocation(ins, ins->get_shape());
refs.push_back(output);
......
......@@ -55,6 +55,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
......@@ -134,6 +135,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
fuse_mlir{&ctx},
dead_code_elimination{},
fuse_ck{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verifies the "ck_gemm" op on a square half-precision GEMM. Dimensions are
// multiples of 8, matching the divisibility requirement the lowering checks
// before selecting the CK path.
struct test_ck_gemm : verify_program<test_ck_gemm>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        const unsigned long m = 256;
        const unsigned long k = m;
        const unsigned long n = k;
        migraphx::shape a_shape{migraphx::shape::half_type, {m, k}};
        migraphx::shape b_shape{migraphx::shape::half_type, {k, n}};
        auto a = mm->add_parameter("1", a_shape);
        auto b = mm->add_parameter("2", b_shape);
        mm->add_instruction(migraphx::make_op("ck_gemm"), a, b);
        return p;
    }
};
// struct test_ck_gemm : verify_program<test_ck_gemm>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// unsigned long m = 3; unsigned long k = 3; unsigned long n = 3;
// migraphx::shape m1_shape{migraphx::shape::half_type, {m, k}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {k, n}};
// std::vector<float> v1(m * k, 1);
// //std::iota(v1.begin(), v1.end(), 1);
// std::vector<float> v2(k * n, 1);
// std::iota(v2.begin(), v2.end(), 1);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// // auto l1 = mm->add_parameter("1", m1_shape);
// // auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// return p;
// }
// };
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {1, 1, 2, 2, 2}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1, 1, 1, 2, 1}};
// std::vector<float> v1(8, 1);
// std::vector<float> v2(2);
// std::iota(v2.begin(), v2.end(), 1);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// mm->debug_print();
// return p;
// }
// };
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {2, 384, 3072}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1, 384, 1}};
// std::vector<float> v1(2*384*3072, 1);
// std::vector<float> v2(384, 2.54);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// mm->debug_print();
// return p;
// }
// };
// Verifies the "ck_elementwise" op in half precision where the second operand
// is multibroadcast from {1, 384, 1} to the full {2, 384, 3072} shape.
struct ck_elementwise_half : verify_program<ck_elementwise_half>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape lhs_shape{migraphx::shape::half_type, {2, 384, 3072}};
        migraphx::shape rhs_shape{migraphx::shape::half_type, {1, 384, 1}};
        auto lhs = mm->add_parameter("1", lhs_shape);
        auto rhs = mm->add_parameter("2", rhs_shape);
        rhs = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), rhs);
        mm->add_instruction(migraphx::make_op("ck_elementwise"), lhs, rhs);
        return p;
    }
};
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {3072}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {3072}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// //mm->debug_print();
// return p;
// }
// };
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verifies the "ck_elementwise" op on two identically-shaped 1-D float inputs.
struct ck_elementwise : verify_program<ck_elementwise>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape in_shape{migraphx::shape::float_type, {20}};
        auto x = mm->add_parameter("1", in_shape);
        auto y = mm->add_parameter("2", in_shape);
        mm->add_instruction(migraphx::make_op("ck_elementwise"), x, y);
        return p;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment