Commit b7e80b6e authored by Paul's avatar Paul
Browse files

Format

parent e34cb7c1
......@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input:inputs)
for(const auto& input : inputs)
check_gemm_shape(input);
return op.compute_shape({a, b});
}
......@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if (a.lens()[1] >= 2048)
if(a.lens()[1] >= 2048)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
......@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm
{
// Find a gemm followed by a pointwise operation.
auto matcher() const {
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
auto matcher() const
{
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
......@@ -77,17 +79,18 @@ struct find_ck_gemm
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if (gemm_idx != 0)
if(gemm_idx != 0)
{
// std::swap(inputs[0], inputs[gemm_idx]);
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
......
......@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return shape::cpp_type(s.type());
}
template<class Iterator, class F>
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
......@@ -158,24 +158,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto m = c_shape.lens().front();
auto n = c_shape.lens().back();
auto i = v.get("tuning_val", get_tuning_for(inputs));
auto i = v.get("tuning_val", get_tuning_for(inputs));
auto instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
});
assert(inputs.size() < 4 or v.contains("post"));
if (v.contains("post"))
if(v.contains("post"))
{
assert(instance[2] == "ck::Tuple<>");
instance[2] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_layout);
instance[2] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
assert(instance[8] == "ck::Tuple<>");
instance[8] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_type);
instance[8] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);
assert(instance[12] == "ck_passthrough");
instance[12] = v.at("post").to<std::string>();
}
hip_compile_options options;
auto block_size = get_block_size(instance);
auto grid_size = get_grid_size(instance, m, n);
......@@ -185,28 +184,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
auto src = interpolate_string(ck_gemm_kernel, {
{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}
});
auto src = interpolate_string(ck_gemm_kernel,
{{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel";
auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
}
auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
......
......@@ -50,13 +50,13 @@ constexpr bool is_row_major()
template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class T>
template <class T>
constexpr auto to_ck_pointer(T* x)
{
return static_cast<to_ck_type<T>*>(x);
}
template<class T>
template <class T>
constexpr auto to_ck_const_pointer(const T* x)
{
return static_cast<const to_ck_type<T>*>(x);
......
......@@ -36,12 +36,12 @@
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \
{ \
template<class... PrivateLiftTs> \
#define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \
{ \
template <class... PrivateLiftTs> \
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
}
namespace migraphx {
......
......@@ -33,9 +33,9 @@ struct gemm_add_relu : verify_program<gemm_add_relu>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("1", {migraphx::shape::half_type, {16, 8}});
auto b = mm->add_parameter("2", {migraphx::shape::half_type, {8, 32}});
auto c = mm->add_parameter("3", {migraphx::shape::half_type, {16, 32}});
auto a = mm->add_parameter("1", {migraphx::shape::half_type, {16, 8}});
auto b = mm->add_parameter("2", {migraphx::shape::half_type, {8, 32}});
auto c = mm->add_parameter("3", {migraphx::shape::half_type, {16, 32}});
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = mm->add_instruction(migraphx::make_op("add"), dot, c);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment