Commit 4e64e2c2 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent f3fcfcc7
...@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm ...@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); // match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw = // auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias"); // match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); // auto softmax =
// return match::name("dot")(is_ck_gemm().bind("gemm2"))( // match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax)); // match::any_of[match::inputs()](softmax));
// } // }
......
...@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm ...@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm
{ {
auto gemm1 = auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale"); auto add = match::name("add")(match::any_of[match::inputs()](mul));
auto add =
match::name("add")(match::any_of[match::inputs()](mul));
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax"); auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))( return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax)); match::any_of[match::inputs()](softmax));
......
...@@ -111,8 +111,8 @@ struct instance ...@@ -111,8 +111,8 @@ struct instance
void set_gemm(const std::string& s) void set_gemm(const std::string& s)
{ {
assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding"); params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
params[15] = s; params[15] = s;
} }
...@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
if (not s.transposed()) if(not s.transposed())
return "ck::tensor_layout::gemm::RowMajor"; return "ck::tensor_layout::gemm::RowMajor";
auto lens = s.lens(); auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2] ? return lens[lens.size() - 1] > lens[lens.size() - 2]
"ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
} }
static std::string get_type(const shape& s) static std::string get_type(const shape& s)
...@@ -185,23 +186,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -185,23 +186,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto b1_shape = inputs[2]; auto b1_shape = inputs[2];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto m = a_shape.lens()[0]; auto m = a_shape.lens()[0];
auto k = a_shape.lens()[1]; auto k = a_shape.lens()[1];
auto n = c_shape.lens()[1]; auto n = c_shape.lens()[1];
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
std::array<char, 4> keys{'M', 'N', 'K', 'O'}; std::array<char, 4> keys{'M', 'N', 'K', 'O'};
// config (m0, n0, k0, n1) // config (m0, n0, k0, n1)
std::array<std::size_t, 4> config{ std::array<std::size_t, 4> config{c_shape.lens()[rank - 2],
c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()}; b_shape.lens()[rank - 2],
a_shape.lens().back(),
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape})); c_shape.lens().back()};
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
auto tuning_val =
v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
...@@ -220,8 +224,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -220,8 +224,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
gemm_type += "Padding"; gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type); ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config); auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2, auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
c_shape.lens().rend(), c_shape.lens().rend(),
std::size_t{1}, std::size_t{1},
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
......
...@@ -935,122 +935,122 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std:: ...@@ -935,122 +935,122 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"8", "8",
"false", "false",
"std::ratio<1, 8>"}, "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor", // {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor", // "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor", // "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor", // "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "float", // "float",
// "ck::half_t", // "ck::half_t",
// "ck_passthrough", // "ck_passthrough",
// "ck_passthrough", // "ck_passthrough",
// "ck_scale", // "ck_scale",
// "ck_passthrough", // "ck_passthrough",
// "ck_passthrough", // "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding", // "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1", // "1",
// "256", // "256",
// "128", // "128",
// "256", // "256",
// "40", // "40",
// "64", // "64",
// "32", // "32",
// "4", // "4",
// "4", // "4",
// "2", // "2",
// "32", // "32",
// "32", // "32",
// "1", // "1",
// "8", // "8",
// "2", // "2",
// "ck::Sequence<2,128,1>", // "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "2", // "2",
// "4", // "4",
// "4", // "4",
// "false", // "false",
// "ck::Sequence<2,128,1>", // "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "2", // "2",
// "4", // "4",
// "4", // "4",
// "false", // "false",
// "ck::Sequence<16,16,1>", // "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>", // "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>", // "ck::Sequence<0,2,1>",
// "1", // "1",
// "4", // "4",
// "2", // "2",
// "false", // "false",
// "1", // "1",
// "2", // "2",
// "ck::Sequence<1,32,1,8>", // "ck::Sequence<1,32,1,8>",
// "8", // "8",
// "false", // "false",
// "std::ratio<1, 8>"}, // "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor", // {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor", // "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor", // "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor", // "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "ck::half_t", // "ck::half_t",
// "float", // "float",
// "ck::half_t", // "ck::half_t",
// "ck_passthrough", // "ck_passthrough",
// "ck_passthrough", // "ck_passthrough",
// "ck_scale", // "ck_scale",
// "ck_passthrough", // "ck_passthrough",
// "ck_passthrough", // "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding", // "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1", // "1",
// "256", // "256",
// "128", // "128",
// "256", // "256",
// "40", // "40",
// "128", // "128",
// "32", // "32",
// "4", // "4",
// "4", // "4",
// "2", // "2",
// "32", // "32",
// "32", // "32",
// "1", // "1",
// "8", // "8",
// "4", // "4",
// "ck::Sequence<2,128,1>", // "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "2", // "2",
// "4", // "4",
// "4", // "4",
// "false", // "false",
// "ck::Sequence<2,128,1>", // "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>", // "ck::Sequence<1,0,2>",
// "2", // "2",
// "4", // "4",
// "4", // "4",
// "false", // "false",
// "ck::Sequence<8,32,1>", // "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>", // "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>", // "ck::Sequence<0,2,1>",
// "1", // "1",
// "4", // "4",
// "2", // "2",
// "false", // "false",
// "1", // "1",
// "2", // "2",
// "ck::Sequence<1,32,1,8>", // "ck::Sequence<1,32,1,8>",
// "8", // "8",
// "false", // "false",
// "std::ratio<1, 8>"}, // "std::ratio<1, 8>"},
{"ck::tensor_layout::gemm::RowMajor", {"ck::tensor_layout::gemm::RowMajor",
"ck::tensor_layout::gemm::ColumnMajor", "ck::tensor_layout::gemm::ColumnMajor",
"ck::tensor_layout::gemm::RowMajor", "ck::tensor_layout::gemm::RowMajor",
......
...@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1) ...@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b_shape = get_shape_c<B>{}; constexpr const auto b_shape = get_shape_c<B>{};
constexpr const auto n = b_shape.lens[1]; constexpr const auto n = b_shape.lens[1];
constexpr const auto sb = b_shape.strides[1]; // col-major constexpr const auto sb = b_shape.strides[1]; // col-major
constexpr const auto BK1 = gemm.get_BK1(); constexpr const auto BK1 = gemm.get_BK1();
constexpr const auto BK0 = k / BK1; constexpr const auto BK0 = k / BK1;
...@@ -85,8 +85,8 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1) ...@@ -85,8 +85,8 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b1_shape = get_shape_c<B1>{}; constexpr const auto b1_shape = get_shape_c<B1>{};
constexpr const auto k1 = b1_shape.lens[0]; constexpr const auto k1 = b1_shape.lens[0];
constexpr const auto n1 = b1_shape.lens[1]; constexpr const auto n1 = b1_shape.lens[1];
constexpr const auto sb1 = b1_shape.strides[0]; // row-major constexpr const auto sb1 = b1_shape.strides[0]; // row-major
constexpr const auto B1K1 = gemm.get_B1K1(); constexpr const auto B1K1 = gemm.get_B1K1();
constexpr const auto B1K0 = k1 / B1K1; constexpr const auto B1K0 = k1 / B1K1;
......
...@@ -50,15 +50,15 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -50,15 +50,15 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
// // a = one; // // a = one;
// // b = one; // // b = one;
// // b1 = one; // // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); // b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto scale =
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); // mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); auto bias =
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero); // mm->add_instruction(migraphx::make_op("add"), scale, zero); auto softmax =
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias); // mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1); // mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
size_t batch = 2; size_t batch = 2;
migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}}; migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}}; migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
...@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g); g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g); g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
auto a = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g); auto a = mm->add_instruction(
auto b = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g); auto b = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
......
...@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib ...@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function = -1 ck_function = -1
@contextlib.contextmanager @contextlib.contextmanager
def tmp_file(dump=None): def tmp_file(dump=None):
tmp_name = None tmp_name = None
...@@ -99,7 +100,6 @@ def parse_log(f): ...@@ -99,7 +100,6 @@ def parse_log(f):
config = json.loads(line) config = json.loads(line)
ck_function = 1 ck_function = 1
yield config yield config
def benchmark_log(f, n): def benchmark_log(f, n):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment