Commit b7e80b6e authored by Paul's avatar Paul
Browse files

Format

parent e34cb7c1
...@@ -38,7 +38,7 @@ struct ck_gemm ...@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0]; auto a = inputs[0];
auto b = inputs[1]; auto b = inputs[1];
for(const auto& input:inputs) for(const auto& input : inputs)
check_gemm_shape(input); check_gemm_shape(input);
return op.compute_shape({a, b}); return op.compute_shape({a, b});
} }
...@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2) if(a.lens().size() > 2 or b.lens().size() > 2)
return false; return false;
if (a.lens()[1] >= 2048) if(a.lens()[1] >= 2048)
return false; return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0); b.lens()[1] % 8 == 0);
...@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm struct find_ck_gemm
{ {
// Find a gemm followed by a pointwise operation. // Find a gemm followed by a pointwise operation.
auto matcher() const { auto matcher() const
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm"))); {
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
} }
...@@ -81,13 +83,14 @@ struct find_ck_gemm ...@@ -81,13 +83,14 @@ struct find_ck_gemm
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if (gemm_idx != 0) if(gemm_idx != 0)
{ {
// std::swap(inputs[0], inputs[gemm_idx]); // std::swap(inputs[0], inputs[gemm_idx]);
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]); auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape()); auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape()); auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param); pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param); pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param); pm->remove_instruction(first_param);
......
...@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return shape::cpp_type(s.type()); return shape::cpp_type(s.type());
} }
template<class Iterator, class F> template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f) static std::string ck_tuple(Iterator start, Iterator last, F f)
{ {
std::vector<std::string> s; std::vector<std::string> s;
...@@ -165,17 +165,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -165,17 +165,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
}); });
assert(inputs.size() < 4 or v.contains("post")); assert(inputs.size() < 4 or v.contains("post"));
if (v.contains("post")) if(v.contains("post"))
{ {
assert(instance[2] == "ck::Tuple<>"); assert(instance[2] == "ck::Tuple<>");
instance[2] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_layout); instance[2] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
assert(instance[8] == "ck::Tuple<>"); assert(instance[8] == "ck::Tuple<>");
instance[8] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_type); instance[8] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);
assert(instance[12] == "ck_passthrough"); assert(instance[12] == "ck_passthrough");
instance[12] = v.at("post").to<std::string>(); instance[12] = v.at("post").to<std::string>();
} }
hip_compile_options options; hip_compile_options options;
auto block_size = get_block_size(instance); auto block_size = get_block_size(instance);
auto grid_size = get_grid_size(instance, m, n); auto grid_size = get_grid_size(instance, m, n);
...@@ -185,13 +184,12 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -185,13 +184,12 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = v.get("kernel", "ck_gemm_kernel"); options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
auto src = interpolate_string(ck_gemm_kernel, { auto src = interpolate_string(ck_gemm_kernel,
{"instance", join_strings(instance, ",")}, {{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name} {"kernel", options.kernel_name}});
});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -203,7 +201,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -203,7 +201,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);"; v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>"; v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
} }
......
...@@ -50,13 +50,13 @@ constexpr bool is_row_major() ...@@ -50,13 +50,13 @@ constexpr bool is_row_major()
template <class T> template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type; using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class T> template <class T>
constexpr auto to_ck_pointer(T* x) constexpr auto to_ck_pointer(T* x)
{ {
return static_cast<to_ck_type<T>*>(x); return static_cast<to_ck_type<T>*>(x);
} }
template<class T> template <class T>
constexpr auto to_ck_const_pointer(const T* x) constexpr auto to_ck_const_pointer(const T* x)
{ {
return static_cast<const to_ck_type<T>*>(x); return static_cast<const to_ck_type<T>*>(x);
......
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
#define MIGRAPHX_LIFT_CLASS(name, ...) \ #define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \ struct name \
{ \ { \
template<class... PrivateLiftTs> \ template <class... PrivateLiftTs> \
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \ constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \ (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment