Commit 48b39e06 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent de99db23
...@@ -176,15 +176,12 @@ template <index_int N, class T, class... Ts> ...@@ -176,15 +176,12 @@ template <index_int N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs) auto hip_vec_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
auto sx = get_shape(x); auto sx = get_shape(x);
auto lens = sx.lens(); auto lens = sx.lens();
lens.back() /= N; lens.back() /= N;
shape ssx{sx.type(), lens}; shape ssx{sx.type(), lens};
hip_visit_all_impl(ssx, hip_visit_all_impl(
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }), ssx, make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }), f, x, xs...);
f,
x,
xs...);
}; };
} }
......
...@@ -81,10 +81,7 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -81,10 +81,7 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
} }
template <index_int MaxBlockSize, class Input, class Output> template <index_int MaxBlockSize, class Input, class Output>
__device__ void layernorm(index idx, __device__ void layernorm(index idx, index_int relements, Input input, Output output)
index_int relements,
Input input,
Output output)
{ {
using value_type = decltype(input(idx.local)); using value_type = decltype(input(idx.local));
const auto relements_v = relements / vector_size<value_type>{}; const auto relements_v = relements / vector_size<value_type>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment