Commit b7e80b6e authored by Paul's avatar Paul
Browse files

Format

parent e34cb7c1
......@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input:inputs)
for(const auto& input : inputs)
check_gemm_shape(input);
return op.compute_shape({a, b});
}
......@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if (a.lens()[1] >= 2048)
if(a.lens()[1] >= 2048)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
......@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm
{
// Find a gemm followed by a pointwise operation.
auto matcher() const {
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
auto matcher() const
{
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
......@@ -81,13 +83,14 @@ struct find_ck_gemm
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if (gemm_idx != 0)
if(gemm_idx != 0)
{
// std::swap(inputs[0], inputs[gemm_idx]);
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
......
......@@ -139,7 +139,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return shape::cpp_type(s.type());
}
template<class Iterator, class F>
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
......@@ -165,17 +165,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
});
assert(inputs.size() < 4 or v.contains("post"));
if (v.contains("post"))
if(v.contains("post"))
{
assert(instance[2] == "ck::Tuple<>");
instance[2] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_layout);
instance[2] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
assert(instance[8] == "ck::Tuple<>");
instance[8] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_type);
instance[8] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);
assert(instance[12] == "ck_passthrough");
instance[12] = v.at("post").to<std::string>();
}
hip_compile_options options;
auto block_size = get_block_size(instance);
auto grid_size = get_grid_size(instance, m, n);
......@@ -185,13 +184,12 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
auto src = interpolate_string(ck_gemm_kernel, {
{"instance", join_strings(instance, ",")},
auto src = interpolate_string(ck_gemm_kernel,
{{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}
});
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
......@@ -203,7 +201,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
......
......@@ -50,13 +50,13 @@ constexpr bool is_row_major()
template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class T>
template <class T>
constexpr auto to_ck_pointer(T* x)
{
return static_cast<to_ck_type<T>*>(x);
}
template<class T>
template <class T>
constexpr auto to_ck_const_pointer(const T* x)
{
return static_cast<const to_ck_type<T>*>(x);
......
......@@ -39,7 +39,7 @@
#define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \
{ \
template<class... PrivateLiftTs> \
template <class... PrivateLiftTs> \
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment