Commit 54dd72b6 authored by Alan Turner's avatar Alan Turner
Browse files

Debug previous commit

parent c1e7454d
......@@ -54,36 +54,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = b.lens()[1];
auto n = a.lens()[0];
auto k = a.lens()[1];
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if(a.lens()[1] >= 2048)
if(a.lens()[1] > 1024)
return false;
return true;
// std::cout << a << std::endl;
// std::cout << b << std::endl;
// printf("m, n, k: %zu, %zu, %zu\n", m, n, k);
// if ((m == 1414 and n == 2048 and k == 512) or
// (m == 4096 and n == 2048 and k == 1414) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 160 and n == 2048 and k == 64) or
// (m == 512 and n == 2048 and k == 512) or
// (m == 39488 and n == 2048 and k == 512) or
// (m == 5120 and n == 2048 and k == 512))
// return true;//(a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
// //b.lens()[1] % 8 == 0);
// return false;
}
struct find_ck_gemm
{
// Find a convolution followed by a pointwise operation.
// Find a gemm that can be replaced with a ck_gemm
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......
......@@ -77,7 +77,7 @@ static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y -
static std::size_t block_size_index = 15;
static std::size_t padding_index = 11;
static std::size_t padding_index = 13;
static std::size_t get_block_size(const std::vector<std::string>& s)
{
......@@ -115,7 +115,6 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
std::cout << inputs[0] << std::endl << inputs[1] << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
......@@ -151,6 +150,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto m = c_shape.lens().front();
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto i = v.get("tuning_val", get_tuning_for(inputs));
auto& instance = get_instance(i, [&](const auto& x) -> bool {
......
......@@ -36,12 +36,7 @@ namespace migraphx {
template <class G, class A, class B, class E, class... Ds>
__device__ void ck_gemm(A a, B b, E e, Ds... ds)
{
constexpr const auto a_grid_desc_ak0_m_ak1 =
G::MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr const auto b_grid_desc_bk0_n_bk1 =
G::MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr const auto c_grid_desc_m_n = G::MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr const auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment