Commit 5616894f authored by Paul's avatar Paul
Browse files

Merge

parents 835cc1e2 4a312201
......@@ -754,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
return false;
auto l = ins->get_literal();
......
......@@ -27,7 +27,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
__global__ void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
......@@ -39,6 +39,18 @@ __global__ void kernel(${params})
)__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
......@@ -131,12 +143,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto preloads = preload(axis, options.virtual_inputs);
auto is_preloading =
std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(v,
compute_global_for(ctx,
options.output.elements() / vec_size,
oversubscribe_if(not is_preloading)));
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()},
{"vec_size", std::to_string(vec_size)},
......@@ -151,7 +165,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
if(op.name() == "contiguous")
{
return replace(compile_op(
ctx, to_shapes(ins->inputs()), {{"lambda", "[](auto x) { return x; }"}}));
ctx, to_shapes(ins->inputs()), {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
}
else
{
......@@ -169,14 +183,18 @@ struct pointwise_compiler : compiler<pointwise_compiler>
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult([](const shape& s) {
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
return replace(compile_op(
ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}}));
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment