Commit 4064d804 authored by Paul
Browse files

Merge

parents 4001eef7 dc296a73
...@@ -754,10 +754,16 @@ auto skip_broadcasts(Ms... ms) ...@@ -754,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...); return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
} }
/// Same shape as skip_broadcasts, but the name matcher handed to `skip`
/// also lists "convert" next to "broadcast", "multibroadcast" and
/// "contiguous". The resulting skipper is then applied to the matchers `ms`.
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
    auto skippable_ops = name("broadcast", "multibroadcast", "contiguous", "convert");
    return skip(skippable_ops)(ms...);
}
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, float tolerance = 1e-6)
{ {
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) { return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal") if(ins->name() != "@literal")
return false; return false;
auto l = ins->get_literal(); auto l = ins->get_literal();
......
...@@ -30,7 +30,7 @@ namespace migraphx { ...@@ -30,7 +30,7 @@ namespace migraphx {
${preamble} ${preamble}
extern "C" { extern "C" {
__global__ void kernel(${params}) __global__ void ${kernel}(${params})
{ {
auto idx = make_index(); auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args}); pointwise(idx, ${transformers})(${lambda}, ${args});
...@@ -42,6 +42,18 @@ __global__ void kernel(${params}) ...@@ -42,6 +42,18 @@ __global__ void kernel(${params})
)__migraphx__"; )__migraphx__";
// Gathers the names of the instructions in module `m`, in iteration order,
// leaving out every instruction whose name starts with "@".
static std::vector<std::string> get_op_names(const module& m)
{
    std::vector<std::string> op_names;
    for(const auto& instr : m)
    {
        // Bind once; works whether name() returns by value or by reference.
        const auto& op_name = instr.name();
        if(not starts_with(op_name, "@"))
            op_names.push_back(op_name);
    }
    return op_names;
}
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
...@@ -53,17 +65,6 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -53,17 +65,6 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else else
return 1; return 1;
} }
// Picks the value that compile_op passes as the last argument of
// set_launch_params (presumably the local/workgroup size - confirm).
//
// Returns max_local (1024) when
//   - no input shape is transposed, or
//   - at least one input is broadcasted or has a stride other than 1 on
//     v.axis.
// Otherwise returns the first input's length along v.axis divided by v.size.
//
// NOTE(review): that quotient is not clamped to max_local and can be 0 when
// the axis length is smaller than v.size - confirm callers tolerate both.
static std::size_t compute_local(gen::vectorize v, const std::vector<shape>& inputs)
{
    const std::size_t max_local = 1024;

    const bool has_transposed_input = std::any_of(
        inputs.begin(), inputs.end(), [](const auto& s) { return s.transposed(); });
    if(not has_transposed_input)
        return max_local;

    // De Morgan of "any input is broadcasted or non-unit-strided on the axis".
    const bool all_unit_stride_on_axis =
        std::all_of(inputs.begin(), inputs.end(), [&](const auto& s) {
            return not s.broadcasted() and s.strides()[v.axis] == 1;
        });
    if(not all_unit_stride_on_axis)
        return max_local;

    // inputs is non-empty here: an empty range has no transposed input and
    // would already have returned above.
    return inputs.front().lens()[v.axis] / v.size;
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
...@@ -72,16 +73,18 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -72,16 +73,18 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec_size = vectorize_elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs); auto preloads = preload(axis, options.virtual_inputs);
options.set_launch_params( auto is_preloading =
v, std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
compute_global_for(ctx, options.kernel_name = v.get("kernel", "kernel");
options.output.elements() / vec.size, options.set_launch_params(v,
oversubscribe_if(not preloads.is_preloading())), compute_global_for(ctx,
compute_local(vec, options.virtual_inputs)); options.output.elements() / vec_size,
oversubscribe_if(not is_preloading)));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(preloads, vec)},
...@@ -94,7 +97,9 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -94,7 +97,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
if(op.name() == "contiguous") if(op.name() == "contiguous")
{ {
return replace(compile_op( return replace(compile_op(
ctx, to_shapes(ins->inputs()), {{"lambda", "[](auto x) { return x; }"}})); ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
} }
else else
{ {
...@@ -118,8 +123,13 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -118,8 +123,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto name = g.create_function( auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op( return replace(compile_op(
ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}})); ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment