Commit 8752d11f authored by Alan Turner
Browse files

Add specialization for lens not divisible by 8

parent 2d18473f
......@@ -54,12 +54,31 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = b.lens()[1];
auto n = a.lens()[0];
auto k = a.lens()[1];
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if (a.lens()[1] >= 2048)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
return true;
// std::cout << a << std::endl;
// std::cout << b << std::endl;
// printf("m, n, k: %zu, %zu, %zu\n", m, n, k);
// if ((m == 1414 and n == 2048 and k == 512) or
// (m == 4096 and n == 2048 and k == 1414) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 160 and n == 2048 and k == 64) or
// (m == 512 and n == 2048 and k == 512) or
// (m == 39488 and n == 2048 and k == 512) or
// (m == 5120 and n == 2048 and k == 512))
// return true;//(a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
// //b.lens()[1] % 8 == 0);
// return false;
}
struct find_ck_gemm
......
......@@ -40,7 +40,7 @@
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
const std::vector<std::string>&
std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
......@@ -77,6 +77,8 @@ static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y -
static std::size_t block_size_index = 13;
static std::size_t padding_index = 11;
static std::size_t get_block_size(const std::vector<std::string>& s)
{
return std::stoull(s[block_size_index]);
......@@ -89,6 +91,11 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}
/// Overwrites the padding entry of a CK instance parameter list.
///
/// @param s Instance parameter strings; the element at `padding_index` is
///          replaced. No bounds check is performed, so `s` must contain more
///          than `padding_index` elements.
/// @param p Replacement text for the padding entry. The caller in this file
///          passes a `ck::tensor_operation::device::GemmSpecialization::...`
///          name.
static void set_padding(std::vector<std::string>& s, const std::string& p)
{
    // Taken by const reference: the only work is a copy-assignment into the
    // slot, so a by-value parameter would just add a second, wasted copy.
    s[padding_index] = p;
}
template <class F, class Action>
auto action_decorate(F f, Action action)
{
......@@ -111,6 +118,8 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
std::cout << inputs[0] << std::endl
<< inputs[1] << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
......@@ -152,12 +161,26 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto sc = c_shape.strides().front();
auto i = v.get("tuning_val", get_tuning_for(inputs));
const auto& instance = get_instance(i, [&](const auto& x) -> bool {
auto& instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
get_type(b_shape) == x[4] and get_type(c_shape) == x[5];
});
const bool pad_m = m % 8;
const bool pad_n = n % 8;
const bool pad_k = k % 8;
if (pad_m or pad_n or pad_k)
{
std::string padding_t = "ck::tensor_operation::device::GemmSpecialization::";
padding_t += pad_m ? "M" : "";
padding_t += pad_n ? "N" : "";
padding_t += pad_k ? "K" : "";
padding_t += "Padding";
set_padding(instance, padding_t);
}
hip_compile_options options;
auto block_size = get_block_size(instance);
auto grid_size = get_grid_size(instance, m, n);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment