Commit 4e64e2c2 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent f3fcfcc7
...@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm ...@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); // match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw = // auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias"); // match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); // auto softmax =
// return match::name("dot")(is_ck_gemm().bind("gemm2"))( // match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax)); // match::any_of[match::inputs()](softmax));
// } // }
......
...@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm ...@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm
{ {
auto gemm1 = auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale"); auto add = match::name("add")(match::any_of[match::inputs()](mul));
auto add =
match::name("add")(match::any_of[match::inputs()](mul));
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax"); auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))( return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax)); match::any_of[match::inputs()](softmax));
......
...@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
if (not s.transposed()) if(not s.transposed())
return "ck::tensor_layout::gemm::RowMajor"; return "ck::tensor_layout::gemm::RowMajor";
auto lens = s.lens(); auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2] ? return lens[lens.size() - 1] > lens[lens.size() - 2]
"ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
} }
static std::string get_type(const shape& s) static std::string get_type(const shape& s)
...@@ -197,10 +198,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -197,10 +198,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
std::array<char, 4> keys{'M', 'N', 'K', 'O'}; std::array<char, 4> keys{'M', 'N', 'K', 'O'};
// config (m0, n0, k0, n1) // config (m0, n0, k0, n1)
std::array<std::size_t, 4> config{ std::array<std::size_t, 4> config{c_shape.lens()[rank - 2],
c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()}; b_shape.lens()[rank - 2],
a_shape.lens().back(),
c_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape})); auto tuning_val =
v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool { auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
......
...@@ -50,11 +50,11 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -50,11 +50,11 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
// // a = one; // // a = one;
// // b = one; // // b = one;
// // b1 = one; // // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); // b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto scale =
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); // mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); auto bias =
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero); // mm->add_instruction(migraphx::make_op("add"), scale, zero); auto softmax =
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias); // mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1); // mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx::program p; migraphx::program p;
...@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g); g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g); g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
auto a = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g); auto a = mm->add_instruction(
auto b = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g); auto b = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
......
...@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib ...@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function = -1 ck_function = -1
@contextlib.contextmanager @contextlib.contextmanager
def tmp_file(dump=None): def tmp_file(dump=None):
tmp_name = None tmp_name = None
...@@ -101,7 +102,6 @@ def parse_log(f): ...@@ -101,7 +102,6 @@ def parse_log(f):
yield config yield config
def benchmark_log(f, n): def benchmark_log(f, n):
result = [] result = []
logs = parse_log(f) logs = parse_log(f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment