Commit 969be85c authored by Paul's avatar Paul
Browse files

Format

parent bfad5a5b
......@@ -45,18 +45,14 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and
b.lens()[0] % 8 == 0 and b.lens()[1] % 8 == 0);
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
}
struct find_ck_gemm
{
// Find a convolution followed by a pointwise operation.
auto matcher() const
{
return match::name("dot")(is_ck_gemm().bind("gemm"));
}
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -67,14 +63,9 @@ struct find_ck_gemm
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm{});
}
void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm{}); }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -77,11 +77,7 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__";
std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
return (x + y - 1) / y;
}
std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
std::size_t get_grid_size(std::size_t m, std::size_t mpb, std::size_t n, std::size_t npb)
{
......@@ -100,25 +96,72 @@ namespace fs = std::filesystem;
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
const std::vector<std::string> instances{
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8"};
const std::vector<block_settings> params {
{256, 256, 128},
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
"16, 1, 8>, 8"};
const std::vector<block_settings> params{{256, 256, 128},
{256, 256, 128},
{256, 128, 256},
{256, 128, 256},
......@@ -140,7 +183,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
int i = 4;
if (contains(v, "tuning_val"))
if(contains(v, "tuning_val"))
i = v.at("tuning_val").to<int>();
assert(i >= 0 and i < instances.size());
......@@ -165,7 +208,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto sa = inputs.front().strides().front();
auto sb = inputs.at(1).strides().front();
auto sc = inputs.back().strides().front();
auto src = interpolate_string(ck_gemm_kernel, {{"instance", instances[i]},
auto src = interpolate_string(ck_gemm_kernel,
{{"instance", instances[i]},
{"m", to_string(m)},
{"k", to_string(k)},
{"n", to_string(n)},
......
......@@ -62,7 +62,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
......@@ -73,7 +72,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
......@@ -120,7 +120,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
__host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
......@@ -177,13 +178,11 @@ template <typename ALayout,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()
>
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CKDeviceGemm
{
//template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA>
static constexpr auto
MakeAGridDescriptor_AK0_M_AK1()
// template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1()
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
......@@ -287,9 +286,8 @@ struct CKDeviceGemm
}
}
//template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB>
static constexpr auto
MakeBGridDescriptor_BK0_N_BK1()
// template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1()
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -393,9 +391,8 @@ struct CKDeviceGemm
}
}
//template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC>
static constexpr auto
MakeCGridDescriptor_M_N()
// template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC>
static constexpr auto MakeCGridDescriptor_M_N()
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment