"vscode:/vscode.git/clone" did not exist on "e9143cd7e644239991b259a2f17a0a035e2b8c5a"
Commit 4e64e2c2 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent f3fcfcc7
......@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
// return match::name("dot")(is_ck_gemm().bind("gemm2"))(
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
......
......@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul =
match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add =
match::name("add")(match::any_of[match::inputs()](mul));
auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add = match::name("add")(match::any_of[match::inputs()](mul));
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
......
......@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
if (not s.transposed())
if(not s.transposed())
return "ck::tensor_layout::gemm::RowMajor";
auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2] ?
"ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
return lens[lens.size() - 1] > lens[lens.size() - 2]
? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static std::string get_type(const shape& s)
......@@ -197,10 +198,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
std::array<char, 4> keys{'M', 'N', 'K', 'O'};
// config (m0, n0, k0, n1)
std::array<std::size_t, 4> config{
c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()};
std::array<std::size_t, 4> config{c_shape.lens()[rank - 2],
b_shape.lens()[rank - 2],
a_shape.lens().back(),
c_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto tuning_val =
v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
......
......@@ -50,11 +50,11 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
// // a = one;
// // b = one;
// // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
// b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto scale =
// mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); auto bias =
// mm->add_instruction(migraphx::make_op("add"), scale, zero); auto softmax =
// mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx::program p;
......@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
auto a = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
auto a = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
......
......@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function = -1
@contextlib.contextmanager
def tmp_file(dump=None):
tmp_name = None
......@@ -101,7 +102,6 @@ def parse_log(f):
yield config
def benchmark_log(f, n):
result = []
logs = parse_log(f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment