Commit fdd86cc4 authored by Paul's avatar Paul
Browse files

Format

parent db816c6f
...@@ -108,7 +108,7 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -108,7 +108,7 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
// Adds a generic (templated) parameter to this function description.
// Appends the pair {name, "T" + name} to `params` (presumably {parameter name, type name} --
// confirm against the param struct) and the matching template parameter declaration
// "class T" + name to `tparams`, so the parameter's type is its own template type.
// Returns *this so calls can be chained.
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& name) cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& name)
{ {
params.push_back({name, "T"+name}); params.push_back({name, "T" + name});
tparams.push_back("class T" + name); tparams.push_back("class T" + name);
return *this; return *this;
} }
...@@ -189,7 +189,8 @@ std::string cpp_generator::generate_point_op(const operation& op, ...@@ -189,7 +189,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
// Returns the result of `impl->fs.str()` (presumably the source text generated so far --
// confirm against cpp_generator_impl).
std::string cpp_generator::str() const { return impl->fs.str(); } std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m, const generate_module_callback& g) cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g)
{ {
function f; function f;
auto name = transform_string(m.name(), [](char c) { auto name = transform_string(m.name(), [](char c) {
...@@ -211,13 +212,14 @@ cpp_generator::function cpp_generator::generate_module(const module& m, const ge ...@@ -211,13 +212,14 @@ cpp_generator::function cpp_generator::generate_module(const module& m, const ge
return f; return f;
} }
// Translates a list of instruction references into their string names.
// For each element of `inputs`, in order, looks up the corresponding entry in `names`
// and collects the results into a vector of the same length.
// The lookup uses `names.at(i)`, so an input with no entry in `names` raises
// std::out_of_range rather than inserting a default value.
std::vector<std::string> cpp_generator::to_args(const std::vector<instruction_ref>& inputs, const std::unordered_map<instruction_ref, std::string>& names) std::vector<std::string>
cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
const std::unordered_map<instruction_ref, std::string>& names)
{ {
std::vector<std::string> args; std::vector<std::string> args;
std::transform(inputs.begin(), std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
inputs.end(), return names.at(i);
std::back_inserter(args), });
[&](auto i) { return names.at(i); });
return args; return args;
} }
......
...@@ -106,7 +106,9 @@ struct cpp_generator ...@@ -106,7 +106,9 @@ struct cpp_generator
std::string create_function(const function& f); std::string create_function(const function& f);
static std::vector<std::string> to_args(const std::vector<instruction_ref>& inputs, const std::unordered_map<instruction_ref, std::string>& names); static std::vector<std::string>
to_args(const std::vector<instruction_ref>& inputs,
const std::unordered_map<instruction_ref, std::string>& names);
private: private:
std::unique_ptr<cpp_generator_impl> impl; std::unique_ptr<cpp_generator_impl> impl;
......
...@@ -207,9 +207,9 @@ struct reduce_op ...@@ -207,9 +207,9 @@ struct reduce_op
{ {
std::string input; std::string input;
std::string reduction = ""; std::string reduction = "";
std::string init = "0"; std::string init = "0";
std::string read = "op::id{}"; std::string read = "op::id{}";
std::string write = "op::id{}"; std::string write = "op::id{}";
std::string str() const std::string str() const
{ {
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))"; return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
...@@ -225,7 +225,7 @@ struct reduce_op ...@@ -225,7 +225,7 @@ struct reduce_op
{ {
auto reduce_elements = get_reduce_elements(ins->inputs()); auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type(); auto reduce_type = ins->inputs().front()->get_shape().type();
r.reduction = "op::sum{}"; r.reduction = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}"; std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
// Use float accumulator when reduction size is too large for half // Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384) if(reduce_type == shape::half_type and reduce_elements > 16384)
...@@ -267,17 +267,18 @@ std::string generate_reduce(const module& rm, const std::string& name) ...@@ -267,17 +267,18 @@ std::string generate_reduce(const module& rm, const std::string& name)
module m = rm; module m = rm;
cpp_generator g; cpp_generator g;
std::size_t i = 0; std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) { auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if (contains(ins->name(), "reduce")) if(contains(ins->name(), "reduce"))
{ {
return reduce_op::generate(ins, names.at(ins->inputs().front())); return reduce_op::generate(ins, names.at(ins->inputs().front()));
} }
else if (ins->name() == "pointwise") else if(ins->name() == "pointwise")
{ {
auto pointwise_name = "pointwise" + std::to_string(i); auto pointwise_name = "pointwise" + std::to_string(i);
i++; i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name); generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
return pointwise_name + "(" + join_strings(cpp_generator::to_args(ins->inputs(), names), ", ") + ")"; return pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), names), ", ") + ")";
} }
MIGRAPHX_THROW("Unknown operator: " + ins->name()); MIGRAPHX_THROW("Unknown operator: " + ins->name());
}); });
...@@ -294,7 +295,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -294,7 +295,7 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
if (ins.name() == "pointwise") if(ins.name() == "pointwise")
{ {
auto names = get_op_names(*ins.module_inputs().front()); auto names = get_op_names(*ins.module_inputs().front());
result.insert(result.end(), names.begin(), names.end()); result.insert(result.end(), names.begin(), names.end());
......
...@@ -84,10 +84,10 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& ...@@ -84,10 +84,10 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
return reduce_lens; return reduce_lens;
} }
template<class T> template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{ {
auto lens = s.lens(); auto lens = s.lens();
for(const auto& axis : axes) for(const auto& axis : axes)
lens[axis] = 1; lens[axis] = 1;
return shape{s.type(), lens}; return shape{s.type(), lens};
...@@ -112,16 +112,14 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs) ...@@ -112,16 +112,14 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
struct fused_reduce_compiler : compiler<fused_reduce_compiler> struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{ {
// Returns the single-element list {"fused_reduce"} (presumably the operator names this
// compiler is registered for -- confirm against the compiler<> base).
std::vector<std::string> names() const std::vector<std::string> names() const { return {"fused_reduce"}; }
{
return {"fused_reduce"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto virtual_inputs = inputs; auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), v.at("axes").to_vector<std::size_t>())); virtual_inputs.push_back(
virtual_inputs = reduce_dims(virtual_inputs); get_reduced_shape(inputs.front(), v.at("axes").to_vector<std::size_t>()));
virtual_inputs = reduce_dims(virtual_inputs);
auto reduced_shape = virtual_inputs.back(); auto reduced_shape = virtual_inputs.back();
virtual_inputs.pop_back(); virtual_inputs.pop_back();
...@@ -155,15 +153,16 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -155,15 +153,16 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
} }
options.kernel_name = v.get("kernel", "reduce_kernel"); options.kernel_name = v.get("kernel", "reduce_kernel");
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src =
{{"kernel", options.kernel_name}, interpolate_string(simple_reduce_kernel,
{"params", enum_params(inputs.size(), "void * private_p")}, {{"kernel", options.kernel_name},
{"args", enum_params(inputs.size(), "private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"algo", algo}, {"args", enum_params(inputs.size(), "private_p")},
{"reduced", "decltype(" + generate_make_shape(reduced_shape) + ")"}, {"algo", algo},
{"lambda", v.at("lambda").to<std::string>()}, {"reduced", "decltype(" + generate_make_shape(reduced_shape) + ")"},
{"transformers", make_transformer_args(vec)}, {"lambda", v.at("lambda").to<std::string>()},
{"preamble", v.get("preamble", std::string{})}}); {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -171,16 +170,12 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -171,16 +170,12 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
// Compiles a fused-reduce instruction into a replacement.
// Starts from the operator's value (`op.to_value()`) and adds three entries derived from
// the instruction's first module input:
//   "preamble" - code produced by generate_reduce for that module, named "fused_reduce_op"
//   "lambda"   - the expression "MIGRAPHX_LIFT(fused_reduce_op)" referring to that code
//   "kernel"   - a kernel name built from the module's ops plus the "_kernel" suffix
// The value is then passed to compile_op together with the shapes of the instruction's
// inputs, and the result is wrapped with replace().
// Precondition (checked by assert only): the instruction has at least one module input.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
assert(not ins->module_inputs().empty()); assert(not ins->module_inputs().empty());
auto v = op.to_value(); auto v = op.to_value();
auto* rm = ins->module_inputs().front(); auto* rm = ins->module_inputs().front();
v["preamble"] = generate_reduce(*rm, "fused_reduce_op"); v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)"; v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel"; v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace( return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
compile_op(ctx,
to_shapes(ins->inputs()),
v));
} }
}; };
} // namespace gpu } // namespace gpu
......
...@@ -195,13 +195,11 @@ constexpr auto compose(Fs... fs) ...@@ -195,13 +195,11 @@ constexpr auto compose(Fs... fs)
})(fs...); })(fs...);
} }
// Two-stage application helper: partial(f)(xs...)(ys...) evaluates f(xs..., ys...).
// `f` and the first argument pack `xs` are captured by copy, so the returned callables
// do not refer back to the caller's objects. The second pack `ys` is taken as forwarding
// references and passed on with static_cast<decltype(ys)>, which preserves each argument's
// original value category.
template<class F> template <class F>
constexpr auto partial(F f) constexpr auto partial(F f)
{ {
return [=](auto... xs) { return [=](auto... xs) {
return [=](auto&&... ys) { return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
return f(xs..., static_cast<decltype(ys)>(ys)...);
};
}; };
} }
......
...@@ -471,8 +471,7 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu ...@@ -471,8 +471,7 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
} }
template <class Algo, class Reduced, class Output, class F> template <class Algo, class Reduced, class Output, class F>
__device__ void __device__ void fused_reduce(Output output, F f)
fused_reduce(Output output, F f)
{ {
Algo::template run<Reduced>([&](auto out_idx, auto r) { Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r); auto result = f(r);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment