Commit e67c8e1e authored by Paul's avatar Paul
Browse files

Format

parent c81b53c5
......@@ -100,7 +100,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)},
{"eps", to_string(eps)}});
return compile_hip_code_object(src, options);
}
......
......@@ -171,13 +171,13 @@ struct inner_storage_tag
{
};
template<class R, class F>
template <class R, class F>
struct storage_access : F
{
using type = R;
};
template<class R, class F>
template <class R, class F>
constexpr storage_access<R, F> make_storage_access(F f)
{
return {{f}};
......@@ -207,10 +207,10 @@ constexpr auto compute_reduce_axis()
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
template<class Derived>
template <class Derived>
struct reducer_base
{
template<class T>
template <class T>
__device__ auto make_inner_slice(T x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
......@@ -220,21 +220,21 @@ struct reducer_base
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
});
}
}
template<class T, class... Ts>
template <class T, class... Ts>
constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
{
MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
return get_size(x);
}
template<class T, class... Ts>
template <class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
......@@ -244,20 +244,18 @@ struct reducer_base
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
auto t = derived.slice(x);
return t.size();
}
}
template<class F>
template <class F>
__device__ auto inner_sliced(F f) const
{
return [=](auto&&... xs) {
return f(get_size(xs...), make_inner_slice(xs)...);
};
return [=](auto&&... xs) { return f(get_size(xs...), make_inner_slice(xs)...); };
}
template<class T>
template <class T>
static __device__ typename T::type& decl_inner_storage(const T&);
template <class F>
......@@ -265,7 +263,7 @@ struct reducer_base
{
return this->inner_sliced([=](auto n, auto&&... xs) {
using result_type = decltype(f(decl_inner_storage(xs)...));
auto&& derived = static_cast<const Derived&>(*this);
auto&& derived = static_cast<const Derived&>(*this);
if constexpr(is_void<result_type>{})
{
derived.inner_void_impl(f, n, xs...);
......@@ -283,7 +281,6 @@ struct reducer_base
return this->inner_sliced([=](auto n, auto&&... xs) {
auto&& derived = static_cast<const Derived&>(*this);
return derived.reduce_impl(op, init, read, n, xs...);
});
}
......@@ -302,7 +299,7 @@ struct reducer_base
template <class Input>
constexpr auto elements() const
{
auto&& derived = static_cast<const Derived&>(*this);
auto&& derived = static_cast<const Derived&>(*this);
using reduce_type = decltype(derived.slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
......@@ -326,7 +323,7 @@ struct block
{
using type = T;
array<T, N> arr;
constexpr Size rsize() const {return {};}
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto& operator()(U, V d) const
{
......
......@@ -141,23 +141,23 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template<class T>
template <class T>
struct remove_cv
{
using type = T;
};
template<class T>
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template<class T>
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template<class T>
template <class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T>
......@@ -187,8 +187,10 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
template<class T>
struct is_void : is_same<void, remove_cv_t<T>> {};
template <class T>
struct is_void : is_same<void, remove_cv_t<T>>
{
};
template <class... Ts>
struct common_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment