Commit 407acb7d authored by Paul
Browse files

Jit contiguous

parent a27dd28c
...@@ -41,7 +41,7 @@ __global__ void kernel(${params}) ...@@ -41,7 +41,7 @@ __global__ void kernel(${params})
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
static std::size_t oversubscribe_if(bool b) static std::size_t oversubscribe_if(bool b)
{ {
...@@ -146,7 +146,14 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -146,7 +146,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
if (op.name() == "contiguous")
{
return replace(
compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", "[](auto x) { return x; }"}}));
}
else
{ {
assert(not ins->module_inputs().empty()); assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
...@@ -170,6 +177,7 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -170,6 +177,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return replace( return replace(
compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}})); compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}}));
} }
}
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -17,10 +17,17 @@ struct implicit_conversion_op ...@@ -17,10 +17,17 @@ struct implicit_conversion_op
template <index_int N, class U> template <index_int N, class U>
constexpr operator vec<U, N>() const constexpr operator vec<U, N>() const
{
if constexpr(vec_size<T>() == 0)
{
return x;
}
else
{ {
static_assert(vec_size<T>() == N, "Vector mismatch size"); static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>); return __builtin_convertvector(x, vec<U, N>);
} }
}
template <class U> template <class U>
constexpr operator U() const constexpr operator U() const
......
...@@ -130,7 +130,7 @@ struct miopen_apply ...@@ -130,7 +130,7 @@ struct miopen_apply
add_generic_op("atan"); add_generic_op("atan");
add_generic_op("atanh"); add_generic_op("atanh");
add_generic_op("ceil"); add_generic_op("ceil");
add_generic_op("contiguous"); // add_generic_op("contiguous");
add_generic_op("cos"); add_generic_op("cos");
add_generic_op("cosh"); add_generic_op("cosh");
add_generic_op("div"); add_generic_op("div");
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment