#include <migraphx/gpu/device/layernorm.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/fast_div.hpp>
#include <migraphx/gpu/device/vec.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

#ifndef MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
#if __AMDGCN_WAVEFRONT_SIZE == 32
#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 1
#else
#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 0
#endif
#endif

template <class T>
struct vector_type
{
};

template <class T, index_int N>
struct vector_type<vec<T, N>>
{
    using type = T;
};

template <class T>
using vector_type_t = typename vector_type<T>::type;

template <class T>
struct vector_size : std::integral_constant<index_int, 1>
{
};

template <class T, index_int N>
struct vector_size<vec<T, N>> : std::integral_constant<index_int, N>
{
};

template <class T, class F>
__device__ auto vec_transform(T x, F f)
{
    return f(x);
}

template <class T, index_int N, class F>
__device__ auto vec_transform(vec<T, N> x, F f)
{
    vec<T, N> y = x;
    // cppcheck-suppress useStlAlgorithm
    for(index_int k = 0; k < N; k++)
        y[k] = f(x[k]);
    return y;
}

template <class T, class U, class Op>
__device__ auto vec_reduce(T x, U, Op)
{
    return x;
}

template <class T, index_int N, class U, class Op>
__device__ auto vec_reduce(vec<T, N> x, U init, Op op)
{
    T r = init;
    for(index_int k = 0; k < N; k++)
        r = op(r, x[k]);
    return r;
}

// Block-reduce across threads, then fold the vector lanes of the per-thread
// result into a scalar with the same op.
template <index_int N, class Op, class T, class F>
__device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
{
    auto r = block_reduce<N>(idx, op, init, n, f);
    return vec_reduce(r, 0, op);
}

template <index_int N, class Input, class Output>
__device__ void layernorm(index_int i,
                          index idx,
                          std::size_t block_size_div,
                          index_int relements,
                          Input input,
                          Output output)
{
    using value_type       = decltype(input(idx.local));
    const auto relements_v = relements / vector_size<value_type>{};
    const auto out_idx     = fast_div(i, block_size_div);
    const auto base_idx    = out_idx * relements_v;
    const auto input_idx   = base_idx + idx.local;
    const bool in_range    = idx.local < relements_v;

    auto mean = [&](auto z) {
        auto m = auto_block_reduce<N>(
                     idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
                 value_type(relements);
#if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
        __builtin_amdgcn_s_barrier();
#endif
        return m;
    };

    // m = x - mean(x)
    value_type x = in_range ? input(input_idx) : 0;
    value_type m = x - mean(x);

    // mean(m ^ 2) + 1e-12
    value_type r = mean(m * m) + value_type(1e-12);

    // m * rsqrt(mean(m ^ 2) + 1e-12)
    if(in_range)
        output(input_idx, m * vec_transform(r, &rsqrt));
}

// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)

template <index_int N, class Input, class Output, class... Arguments>
void layernorm_vec_impl(hipStream_t stream,
                        index_int nelements,
                        index_int relements,
                        Input in,
                        Output out,
                        const argument& result,
                        const Arguments&... args)
{
    hip_vec_visit_all<N>(result, args...)([&](auto output, auto... inputs) {
        const auto relements_v           = relements / N;
        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(relements_v, max_block_size);
        const std::size_t block_size_div = encode_divisor(block_size);
        assert(relements_v <= block_size);

        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            layernorm<max_block_size>(
                i,
                idx,
                block_size_div,
                relements,
                [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
                [&](auto input_idx, auto x) {
                    out(x, output.data()[input_idx], inputs.data()[input_idx]...);
                });
        });
    });
}
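// Worked example (illustrative only, not compiled; the shape is an assumed
// input): for a float tensor of shape {8, 64, 768}, relements = 768 and
// nelements = 8 * 64 = 512. Since 768 % 4 == 0, the vectorized path above is
// taken with N = 4: each thread reads one vec<float, 4>, so relements_v = 192
// and a single block reduces one row. Inside the kernel, the row a thread
// works on is recovered with the precomputed divisor:
//
//     out_idx   = fast_div(i, block_size_div); // i / block_size
//     base_idx  = out_idx * relements_v;       // first vector of the row
//     input_idx = base_idx + idx.local;        // this thread's vector
//
// mean() divides by value_type(relements) rather than relements_v because
// each per-thread element folded by vec_reduce already sums 4 scalar values.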
template <class Input, class Output, class... Arguments>
void layernorm_impl(hipStream_t stream,
                    index_int nelements,
                    index_int relements,
                    Input in,
                    Output out,
                    const argument& result,
                    const Arguments&... args)
{
    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(relements, max_block_size);
        const std::size_t block_size_div = encode_divisor(block_size);
        assert(relements <= block_size);

        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            layernorm<max_block_size>(
                i,
                idx,
                block_size_div,
                relements,
                [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
                [&](auto input_idx, auto x) {
                    out(x, output.data()[input_idx], inputs.data()[input_idx]...);
                });
        });
    });
}

template <class... Arguments>
auto layernorm_fusion(hipStream_t stream,
                      const argument& result,
                      const argument& arg1,
                      const Arguments&... args)
{
    return [=](auto input, auto output) {
        auto relements = arg1.get_shape().lens().back();
        auto nelements = result.get_shape().elements() / relements;
        // auto output_shape = result.get_shape();
        // auto reduce_output_lens(output_shape.lens());
        // reduce_output_lens.back() = 1;

        if((relements % 4) == 0)
            layernorm_vec_impl<4>(
                stream, nelements, relements, input, output, result, arg1, args...);
        else if(relements < 256)
            layernorm_impl(stream, nelements, relements, input, output, result, arg1, args...);
        else
            MIGRAPHX_THROW("No kernel for layernorm");
    };
}

void triadd_layernorm(hipStream_t stream,
                      const argument& result,
                      const argument& arg1,
                      const argument& arg2,
                      const argument& arg3)
{
    layernorm_fusion(stream, result, arg1, arg2, arg3)(
        [](auto x, auto y, auto z) { return x + y + z; }, [](auto x, auto& y, auto...) { y = x; });
}

void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
    layernorm_fusion(stream, result, arg1)([](auto x) { return x; },
                                           [](auto x, auto& y, auto) { y = x; });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
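// Example usage (hypothetical host code; stream creation and buffer setup
// elided): x, y, z, and out are migraphx::argument objects of identical
// shape backed by device memory, normalized over the last axis:
//
//     triadd_layernorm(stream, out, x, y, z); // out = layernorm(x + y + z)
//     layernorm(stream, out, x);              // out = layernorm(x)
//
// Note the dispatch above: the last-axis length must either be a multiple of
// 4 (vectorized path) or be less than 256 (scalar path); anything else
// raises MIGRAPHX_THROW("No kernel for layernorm").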