Commit 4e64e2c2 authored by Alan Turner

Formatting

parent f3fcfcc7
@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
// return match::name("dot")(is_ck_gemm().bind("gemm2"))(
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
......
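For context, the commented-out matcher in the hunk above, reassembled here into one readable piece, describes the fusion target: a dot feeding a pointwise scale/bias, then a softmax, then a second dot. This is a sketch rebuilt from the comments only; it is not active code in this commit.

```cpp
// Reassembled from the commented-out lines above (inactive in this commit):
// match dot -> pointwise(scale/bias) -> softmax -> dot for CK fusion.
auto matcher() const
{
    auto gemm1 =
        match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
    auto pw =
        match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
    auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
    return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
}
```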
@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul =
match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add =
match::name("add")(match::any_of[match::inputs()](mul));
auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add = match::name("add")(match::any_of[match::inputs()](mul));
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
......
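The matcher in find_gemm_softmax_gemm_gemm recognizes an attention-style chain: dot, then mul (bound as "scale"), add, softmax, and a second dot. As a hedged illustration, the IR it matches looks like the following; mm, a, b, b1, scale_lit, and bias_lit are placeholder names, mirroring the test program later in this diff.

```cpp
// Illustrative only: the instruction chain the matcher above recognizes.
auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);        // bound as "gemm1"
auto scaled  = mm->add_instruction(migraphx::make_op("mul"), gemm1, scale_lit); // bound as "scale"
auto biased  = mm->add_instruction(migraphx::make_op("add"), scaled, bias_lit);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), biased);
auto gemm2   = mm->add_instruction(migraphx::make_op("dot"), softmax, b1); // bound as "gemm2"
```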
@@ -111,8 +111,8 @@ struct instance
void set_gemm(const std::string& s)
{
assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
params[15] = s;
}
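set_gemm only overwrites params[15] after asserting it currently holds one of the two expected GemmSpecialization values. A minimal usage sketch, assuming an instance ip built as in compile_op further down:

```cpp
// Hypothetical usage: force the padded specialization when the problem
// sizes are not multiples of the tile sizes (see compile_op below).
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
```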
@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
if (not s.transposed())
if(not s.transposed())
return "ck::tensor_layout::gemm::RowMajor";
auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2] ?
"ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
return lens[lens.size() - 1] > lens[lens.size() - 2]
? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
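get_layout first trusts s.transposed(): non-transposed shapes are RowMajor outright, and only transposed shapes fall back to the size heuristic on the last two dimensions. A hedged worked example with made-up dimensions:

```cpp
// Illustrative only (made-up dimensions): the heuristic on a transposed shape
// compares the last two lens.
std::vector<std::size_t> lens{2, 12, 64, 384};
bool col_major = lens[lens.size() - 1] > lens[lens.size() - 2]; // 384 > 64 -> ColumnMajor
// With lens {2, 12, 384, 64} the comparison fails and RowMajor is returned.
```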
static std::string get_type(const shape& s)
@@ -185,23 +186,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto b1_shape = inputs[2];
auto c_shape = inputs.back();
auto m = a_shape.lens()[0];
auto k = a_shape.lens()[1];
auto n = c_shape.lens()[1];
auto c_shape = inputs.back();
auto m = a_shape.lens()[0];
auto k = a_shape.lens()[1];
auto n = c_shape.lens()[1];
auto rank = a_shape.lens().size();
std::array<char, 4> keys{'M', 'N', 'K', 'O'};
// config (m0, n0, k0, n1)
std::array<std::size_t, 4> config{
c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
std::array<std::size_t, 4> config{c_shape.lens()[rank - 2],
b_shape.lens()[rank - 2],
a_shape.lens().back(),
c_shape.lens().back()};
auto tuning_val =
v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
@@ -220,8 +224,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
c_shape.lens().rend(),
std::size_t{1},
std::multiplies<std::size_t>());
......
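The batch_count reduction above multiplies every output dimension except the last two (the matrix dims), walking the lens in reverse from rbegin() + 2 to rend(). A standalone check of that idiom, with assumed lens {batch, heads, M, N}:

```cpp
// Standalone sketch of the batch_count reduction used in compile_op above.
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{2, 12, 384, 64}; // assumed {batch, heads, M, N}
    // Multiply all dims except the trailing two matrix dimensions.
    auto batch_count = std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>{});
    return batch_count == 24 ? 0 : 1; // 2 * 12 == 24
}
```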
@@ -935,122 +935,122 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"8",
"false",
"std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "64",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "2",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "128",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "4",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "64",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "2",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "128",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "4",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
{"ck::tensor_layout::gemm::RowMajor",
"ck::tensor_layout::gemm::ColumnMajor",
"ck::tensor_layout::gemm::RowMajor",
......
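The instance strings above are positional template parameters. A partial, hedged map of the fields, inferred only from the indices this file itself accesses in compile_op and set_gemm:

```cpp
// Field indices inferred from this file's own accesses (partial, hedged):
// x[0], x[1], x[3] -> A, B, C layouts      (compile_op's get_layout checks)
// x[4], x[5], x[9] -> A, B, C element types (compile_op's get_type checks)
// params[15]       -> GemmSpecialization    (asserted and overwritten in set_gemm)
```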
@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b_shape = get_shape_c<B>{};
constexpr const auto n = b_shape.lens[1];
constexpr const auto n = b_shape.lens[1];
constexpr const auto sb = b_shape.strides[1]; // col-major
constexpr const auto BK1 = gemm.get_BK1();
constexpr const auto BK0 = k / BK1;
@@ -85,8 +85,8 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b1_shape = get_shape_c<B1>{};
constexpr const auto k1 = b1_shape.lens[0];
constexpr const auto n1 = b1_shape.lens[1];
constexpr const auto k1 = b1_shape.lens[0];
constexpr const auto n1 = b1_shape.lens[1];
constexpr const auto sb1 = b1_shape.strides[0]; // row-major
constexpr const auto B1K1 = gemm.get_B1K1();
constexpr const auto B1K0 = k1 / B1K1;
......
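The device kernel derives its block counts from compile-time shape constants, e.g. BK0 = k / BK1 and B1K0 = k1 / B1K1. A minimal sketch with assumed values, assuming BK1 evenly divides K as the kernel requires:

```cpp
// Illustrative only: compile-time split of K into BK0 blocks of width BK1.
constexpr std::size_t k   = 64;
constexpr std::size_t BK1 = 8;       // stand-in for gemm.get_BK1()
constexpr std::size_t BK0 = k / BK1; // 8 blocks
static_assert(BK0 * BK1 == k, "BK1 must divide K exactly");
```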
@@ -50,15 +50,15 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
// // a = one;
// // b = one;
// // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
// b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto scale =
// mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); auto bias =
// mm->add_instruction(migraphx::make_op("add"), scale, zero); auto softmax =
// mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
size_t batch = 2;
migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
auto a = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
auto a = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
......
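After the reshape to {batch, 384, 36, 64} and the {0, 2, 1, 3} transpose, g has lens {batch, 36, 384, 64}, and the three slices carve axis 1 into three 12-head operands. A hedged shape trace of that test setup (not part of the test itself):

```cpp
// Illustrative only: expected operand shapes after the slices above.
// g:  {batch, 36, 384, 64}
// a:  slice(axis=1,  0..12) -> {batch, 12, 384, 64}
// b:  slice(axis=1, 12..24) -> {batch, 12, 384, 64}
// b1: slice(axis=1, 24..36) -> {batch, 12, 384, 64}
```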
@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function = -1
@contextlib.contextmanager
def tmp_file(dump=None):
tmp_name = None
@@ -99,7 +100,6 @@ def parse_log(f):
config = json.loads(line)
ck_function = 1
yield config
def benchmark_log(f, n):
......