Unverified Commit 1b7af545 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Fix conversion issue in layernorm fusion (#1483) (#1493)

parent fe19455f
......@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2);
check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2);
check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0);
}
argument
......@@ -159,7 +159,7 @@ struct hip_copy
std::string name() const { return "hip::copy"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).same_type();
return inputs.at(1);
}
argument compute(context& ctx, const shape&, std::vector<argument> args) const
......
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
......
......@@ -33,38 +33,6 @@
namespace migraphx {
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
if constexpr(vec_size<T>() == 0)
{
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
......
......@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
}
}
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
if constexpr(vec_size<T>() == 0)
{
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
......@@ -51,17 +51,20 @@ struct layernorm_base
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
auto t = s.type();
if(not mods.empty())
t = mods.front()->get_output_shapes().front().type();
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
return {t, s.lens()};
}
else
{
return s.with_lens(s.lens());
return s.with_lens(t, s.lens());
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment