Commit 31065c7d authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
...@@ -21,63 +21,69 @@ ...@@ -21,63 +21,69 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/device/gelu.hpp> #ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/gpu/device/nary.hpp> #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/gpu/device/types.hpp> #include <migraphx/kernels/reduce.hpp>
#include <cmath> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// x * 0.5 * (1.0 + erf(x / sqrt(2.0))) template <class T, index_int N, class Op>
template <class T> constexpr auto vec_reduce(const array<T, N>& a, Op op)
auto gelu_fn(T x) __device__
{ {
return x * 0.5 * (1 + ::erf(x * M_SQRT1_2)); return a.apply([&](auto x) { return vec_reduce(x, op); });
} }
// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))) template <index_int Axis,
template <class T> class F,
auto gelu_fn_new(T x) __device__ class BinOp,
class Output,
class Input1,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
return 0.5 * x * (1 + tanh(sqrt(M_2_PI) * (x + 0.044715 * x * x * x))); using reduce_output = reduce::with_axis<Input1, Axis>;
} reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto means =
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2);
void gelu(hipStream_t stream, const argument& result, const argument& arg) auto mean_x = means[0];
{ auto mean_x2 = means[1];
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn(to_hip_type(x)); }); auto variance = mean_x2 - (mean_x * mean_x);
} value_type eps_val = eps; // implicit conversion for eps
void gelu_new(hipStream_t stream, const argument& result, const argument& arg) r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
{ auto x = op(x1, x2);
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); }); auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...);
});
} }
void add_gelu(hipStream_t stream, template <index_int Axis, class F, class Output, class Input, class... Inputs>
const argument& result, __device__ void layernorm(F compute, float eps, Output output, Input input, Inputs... inputs)
const argument& arg1,
const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { generic_binary_layernorm<Axis>(
auto sum = to_hip_type(x + y); compute, [](auto x, auto) { return x; }, eps, output, input, input, inputs...);
return gelu_fn(sum);
});
} }
void add_gelu_new(hipStream_t stream, template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
const argument& result, __device__ void
const argument& arg1, add_layernorm(F compute, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { generic_binary_layernorm<Axis>(
auto sum = to_hip_type(x + y); compute, [](auto x1, auto x2) { return x1 + x2; }, eps, output, input1, input2, inputs...);
return gelu_fn(sum);
});
} }
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor) ...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin) MIGRAPHX_DEVICE_MATH(sin, ::sin)
...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh) ...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan) MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh) MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH(fmod, ::fmod)
// Float overloads // Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) ...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions // Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) ...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan) MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt) MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin) MIGRAPHX_DEVICE_MATH_VEC(sin)
......
...@@ -90,7 +90,7 @@ struct lowest ...@@ -90,7 +90,7 @@ struct lowest
template <class T> template <class T>
constexpr operator T() const constexpr operator T() const
{ {
return numeric_lowest<T>(); return numeric_lowest<vec_type<T>>();
} }
}; };
...@@ -99,7 +99,7 @@ struct highest ...@@ -99,7 +99,7 @@ struct highest
template <class T> template <class T>
constexpr operator T() const constexpr operator T() const
{ {
return numeric_max<T>(); return numeric_max<vec_type<T>>();
} }
}; };
} // namespace migraphx } // namespace migraphx
......
...@@ -21,44 +21,43 @@ ...@@ -21,44 +21,43 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP #define MIGRAPHX_GUARD_KERNELS_PAD_HPP
#include <migraphx/gpu/oper.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/gpu/device/where.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_where : ternary_device<hip_where, device::where> template <class Offsets, class Input, class Output, class PadVal>
__device__ void pad(const index& idx,
const Offsets& offsets,
const Input& input,
Output& output,
const PadVal& pad_val)
{ {
shape compute_shape(const std::vector<shape>& inputs) const auto output_shape = output.get_shape();
{ idx.global_stride(output_shape.elements(), [&](auto i) {
check_shapes{inputs, *this}.has(4).same_dims(); // 1. get current multi-index for output
auto s1 = inputs.at(1); // 2. get the size of the input to determine input boundaries
auto s2 = inputs.at(2); // 3. compute the corresponding multi-index for input by accounting for offsets
if(s1 == s2 and s1.packed()) // 4. if current multi-index is within offsets or input's new multi-index is out of bounds,
{ // use pad value instead of input's value
return s1; auto multi = output_shape.multi(i);
} auto input_bounds = input.get_shape().lens;
else if(s1.packed() != s2.packed()) auto input_idx = multi - offsets;
{ auto range_multi = range(multi.size());
return s1.packed() ? s1 : s2;
} if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
else if(s1.broadcasted() != s2.broadcasted()) return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
{ }))
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens()); output[multi] = pad_val;
}
else else
{ output[multi] = input[input_idx];
return {s1.type(), s1.lens()}; });
} }
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -21,22 +21,29 @@ ...@@ -21,22 +21,29 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_SQRT_HPP #ifndef MIGRAPHX_GUARD_KERNELS_RANGES_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQRT_HPP #define MIGRAPHX_GUARD_KERNELS_RANGES_HPP
#include <migraphx/gpu/oper.hpp> #include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/gpu/device/sqrt.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqrt : unary_device<hip_sqrt, device::sqrt> template <class Iterator>
struct iterator_range
{ {
Iterator start;
Iterator last;
constexpr Iterator begin() const { return start; }
constexpr Iterator end() const { return last; }
}; };
} // namespace gpu constexpr iterator_range<iota_iterator> range(diff_int start, diff_int last)
} // namespace MIGRAPHX_INLINE_NS {
} // namespace migraphx return {{start, {}}, {last, {}}};
}
constexpr iterator_range<iota_iterator> range(diff_int last) { return range(0, last); }
#endif } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_RANGES_HPP
...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max) ...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min) MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul) MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return y; return y;
} }
#else #else
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x; buffer[idx.local] = x;
...@@ -196,17 +197,14 @@ struct block ...@@ -196,17 +197,14 @@ struct block
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
op, return vec_reduce(read(x[j], xs[j]...), op);
init, });
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
...@@ -220,10 +218,22 @@ struct block ...@@ -220,10 +218,22 @@ struct block
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
}; };
template <class Slicer> template <class Slicer>
...@@ -250,11 +260,11 @@ struct lane ...@@ -250,11 +260,11 @@ struct lane
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
using type = typename decltype(x)::type; using type = typename decltype(x)::type;
type r = init; type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
...@@ -274,13 +284,20 @@ struct lane ...@@ -274,13 +284,20 @@ struct lane
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(x[j], xs[j]...);
} }
}); });
} }
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slice(Input{}));
return get_shape_c<reduce_type>{}.elements();
}
}; };
template <class Slicer> template <class Slicer>
......
...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output> ...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); #ifdef MIGRAPHX_USE_FAST_SOFTMAX
auto batch_sum = const auto c = vec_at(r.slice(input)[0], 0);
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); #else
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
input); #endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
}); });
} }
......
...@@ -192,9 +192,13 @@ struct common_type<T, U, Us...> ...@@ -192,9 +192,13 @@ struct common_type<T, U, Us...>
template <class... Ts> template <class... Ts>
using common_type_t = typename common_type<Ts...>::type; using common_type_t = typename common_type<Ts...>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; } constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
template <class T> template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{})>
constexpr T numeric_max() constexpr T numeric_max()
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
...@@ -230,8 +234,6 @@ constexpr T numeric_lowest() ...@@ -230,8 +234,6 @@ constexpr T numeric_lowest()
} }
} }
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -175,7 +175,7 @@ template <class T, class Op> ...@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr auto vec_reduce(T x, Op op) constexpr auto vec_reduce(T x, Op op)
{ {
if constexpr(vec_size<T>() < 2) if constexpr(vec_size<T>() < 2)
return x; return vec_type<T>{x};
else else
{ {
vec_type<T> result = x[0]; vec_type<T> result = x[0];
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_leaky_relu::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_leaky_relu::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(ctx.get_stream().get_miopen(),
ad.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
return args[1];
}
void miopen_leaky_relu::finalize(context&, const shape&, const std::vector<shape>&)
{
ad = make_leaky_relu(op.alpha);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -26,43 +26,24 @@ ...@@ -26,43 +26,24 @@
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/greater.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp> #include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/less.hpp>
#include <migraphx/gpu/logical_and.hpp>
#include <migraphx/gpu/logical_or.hpp>
#include <migraphx/gpu/logical_xor.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/gpu/compiler.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
...@@ -99,86 +80,26 @@ struct miopen_apply ...@@ -99,86 +80,26 @@ struct miopen_apply
(void)i; (void)i;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 auto& ctx = get_context();
auto& ctx = get_context(); int8_x4_format = get_int8_x4_format(ctx);
const auto device_name = trim(split_string(get_device_name(), ':').front()); compute_fp32 = get_compute_fp32_flag();
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("acos");
add_generic_op("acosh");
add_generic_op("add");
add_generic_op("asin");
add_generic_op("asinh");
add_generic_op("atan");
add_generic_op("atanh");
add_generic_op("ceil");
add_generic_op("contiguous"); add_generic_op("contiguous");
add_generic_op("cos");
add_generic_op("cosh");
add_generic_op("div");
add_generic_op("equal");
add_generic_op("erf");
add_generic_op("exp");
add_generic_op("floor");
add_generic_op("greater");
add_generic_op("less");
add_generic_op("log");
add_generic_op("logical_and");
add_generic_op("logical_or");
add_generic_op("logical_xor");
add_generic_op("max");
add_generic_op("min");
add_generic_op("mul");
add_generic_op("not");
add_generic_op("pow");
add_generic_op("prelu");
add_generic_op("recip");
add_generic_op("relu");
add_generic_op("round");
add_generic_op("rsqrt");
add_generic_op("sigmoid");
add_generic_op("sign");
add_generic_op("sin");
add_generic_op("sinh");
add_generic_op("sqdiff");
add_generic_op("sqrt");
add_generic_op("sub");
add_generic_op("tan");
add_generic_op("tanh");
add_generic_op("where");
add_extend_op("abs");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("clip");
add_extend_op("concat");
add_extend_op("convert");
add_extend_op("elu");
add_extend_op("gather"); add_extend_op("gather");
add_extend_op("leaky_relu");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
add_extend_op("lrn"); add_extend_op("lrn");
add_extend_op("multinomial"); add_extend_op("multinomial");
add_extend_op("nonzero"); add_extend_op("nonzero");
add_extend_op("pad");
add_extend_op("pooling"); add_extend_op("pooling");
add_extend_op("prefix_scan_sum"); add_extend_op("prefix_scan_sum");
add_extend_op("reverse"); add_extend_op("reverse");
...@@ -188,16 +109,15 @@ struct miopen_apply ...@@ -188,16 +109,15 @@ struct miopen_apply
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("topk"); add_extend_op("topk");
add_batch_norm_inference_op(); add_convolution_op<op::convolution>("convolution");
add_convolution_op(); add_convolution_op<op::deconvolution>("deconvolution");
add_deconvolution_op(); add_convolution_op<op::quant_convolution>("quant_convolution");
add_gemm_op<op::dot>("dot"); add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
add_if_op(); add_if_op();
add_loop_op(); add_loop_op();
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_quant_convolution_op();
} }
void copy_params() const void copy_params() const
...@@ -246,7 +166,8 @@ struct miopen_apply ...@@ -246,7 +166,8 @@ struct miopen_apply
init(); init();
for(auto it = mod->begin(); it != mod->end(); it++) for(auto it = mod->begin(); it != mod->end(); it++)
{ {
auto s = it->get_shape(); auto s = it->get_shape();
auto attrs = it->get_operator().attributes();
if(apply_map.count(it->name()) > 0) if(apply_map.count(it->name()) > 0)
{ {
check_shape(s, apply_map.at(it->name())(it)); check_shape(s, apply_map.at(it->name())(it));
...@@ -255,11 +176,37 @@ struct miopen_apply ...@@ -255,11 +176,37 @@ struct miopen_apply
{ {
check_shape(s, insert_precompile_op(it)); check_shape(s, insert_precompile_op(it));
} }
else if(attrs.contains("target"))
{
check_shape(s, insert_custom_op(it, attrs));
}
} }
copy_params(); copy_params();
} }
instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const
{
const auto& custom_op = ins->get_operator();
if(attrs.at("target") == "cpu")
{
auto s = ins->get_shape();
std::vector<instruction_ref> cpu_inputs;
auto inputs = ins->inputs();
auto output = inputs.back();
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, custom_op, cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
}
return ins;
}
instruction_ref insert_precompile_op(instruction_ref ins) const instruction_ref insert_precompile_op(instruction_ref ins) const
{ {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
...@@ -278,38 +225,6 @@ struct miopen_apply ...@@ -278,38 +225,6 @@ struct miopen_apply
return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
} }
void add_convolution_op()
{
apply_map.emplace("convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::convolution>(ins->get_operator());
auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction(
ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
});
}
void add_deconvolution_op()
{
apply_map.emplace("deconvolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::deconvolution>(ins->get_operator());
auto conv = miopen_deconvolution{op, make_deconv(op)};
auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction(
ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
});
}
template <typename Op> template <typename Op>
void add_gemm_op(const std::string& name) void add_gemm_op(const std::string& name)
{ {
...@@ -323,31 +238,33 @@ struct miopen_apply ...@@ -323,31 +238,33 @@ struct miopen_apply
}); });
} }
void add_quant_convolution_op() template <typename Op>
void add_convolution_op(const std::string& name)
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); operation conv =
shape ws; miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), int8_x4_format};
miopen_quant_convolution conv; migraphx::context ctx = get_context();
auto compile_quant_conv_with_format = [&](bool format) { size_t ws_bytes = 0;
conv = miopen_quant_convolution{op, format, make_conv(op)}; auto compile_conv_with_format = [&](bool format) {
ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs())); conv = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), format};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
ws_bytes = ws.get("workspace", 0);
}; };
try try
{ { // for the regular convolution and deconvolution, this try would always succeed
compile_quant_conv_with_format(int8_x4_format); compile_conv_with_format(int8_x4_format);
} }
catch(migraphx::exception&) catch(migraphx::exception&)
{ {
// In case no solver supports the default format, retry using the other format. // In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format(!int8_x4_format); compile_conv_with_format(not int8_x4_format);
} }
auto args = ins->inputs(); auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
auto workspace = insert_allocation(ins, shape{shape::int8_type, {ws_bytes}});
return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output); return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
}); });
} }
...@@ -382,43 +299,6 @@ struct miopen_apply ...@@ -382,43 +299,6 @@ struct miopen_apply
}); });
} }
void add_batch_norm_inference_op()
{
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
auto&& op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
shape old_shape = ins->inputs().at(1)->get_shape();
auto input = ins->inputs()[0];
auto input_lens = input->get_shape().lens();
std::vector<int64_t> rsp_lens(input_lens.size(), 1);
// for per_activation case, also need to reshape input
if(op.bn_mode == op::batch_norm_inference::per_activation)
{
std::copy(input_lens.begin() + 1, input_lens.end(), rsp_lens.begin() + 1);
}
else
{
rsp_lens[1] = static_cast<int64_t>(old_shape.elements());
}
auto reshape_op = op::reshape{rsp_lens};
std::vector<instruction_ref> reshapes;
std::transform(ins->inputs().begin() + 1,
ins->inputs().end(),
std::back_inserter(reshapes),
[&](auto i) { return mod->insert_instruction(ins, reshape_op, i); });
return mod->replace_instruction(ins,
miopen_batch_norm_inference{op},
input,
reshapes[0],
reshapes[1],
reshapes[2],
reshapes[3],
output);
});
}
// use 0 - input to represent neg // use 0 - input to represent neg
void add_neg_op() void add_neg_op()
{ {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/make_op.hpp"
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
...@@ -43,8 +44,9 @@ ...@@ -43,8 +44,9 @@
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/perfdb.hpp> #include <migraphx/gpu/perfdb.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <deque> #include <deque>
#include <variant> #include <variant>
...@@ -78,7 +80,7 @@ struct mlir_handle ...@@ -78,7 +80,7 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return !(x == y); } friend bool operator!=(ptr x, ptr y) { return not(x == y); }
T obj{}; T obj{};
}; };
...@@ -370,7 +372,11 @@ struct mlir_program ...@@ -370,7 +372,11 @@ struct mlir_program
mlir_operation_state& add_results(const std::vector<shape>& outputs) mlir_operation_state& add_results(const std::vector<shape>& outputs)
{ {
auto x = prog->make_tensors(outputs); std::vector<shape> reshaped(outputs.size());
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
mlirOperationStateAddResults(&op_state, x.size(), x.data()); mlirOperationStateAddResults(&op_state, x.size(), x.data());
return *this; return *this;
} }
...@@ -502,11 +508,12 @@ struct mlir_program ...@@ -502,11 +508,12 @@ struct mlir_program
{ {
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
std::string tuned = get_tune_params();
if(!tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
// check if HW supports xdlops // check if HW supports xdlops
if(contains(get_xdlops_archs(), target_name)) bool xdlops = contains(get_xdlops_archs(), target_name);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
if(xdlops)
ops.add_attributes({{"xdlopsV2", true}}); ops.add_attributes({{"xdlopsV2", true}});
} }
...@@ -571,7 +578,7 @@ struct mlir_program ...@@ -571,7 +578,7 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
std::string get_tune_params() { return get_mlir_perf_for_conv(pp); } std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
mlir_context ctx; mlir_context ctx;
MlirLocation location; MlirLocation location;
...@@ -589,8 +596,54 @@ std::string dump_mlir(const module& m) ...@@ -589,8 +596,54 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op); return mlir_print(&mlirOperationPrint, mod_op);
} }
code_object_op compile_mlir(const context&, const module& m) void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
{ {
auto names = m.get_parameter_names();
std::sort(names.begin(), names.end());
for(auto i : range(names.size()))
{
const auto& name = names[i];
const auto& input = inputs[i]->get_shape();
auto param = m.get_parameter(name);
if(input.standard())
continue;
auto lens = input.lens();
auto strides = input.strides();
std::vector<operation> ops;
if(input.transposed())
{
auto perm = find_permutation(input);
auto iperm = invert_permutation(perm);
lens = reorder_dims(lens, iperm);
strides = reorder_dims(strides, iperm);
ops.push_back(make_op("transpose", {{"permutation", perm}}));
}
if(input.broadcasted())
{
std::transform(lens.begin(),
lens.end(),
strides.begin(),
lens.begin(),
[](auto len, auto stride) -> std::size_t {
if(stride == 0)
return 1;
return len;
});
ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
}
auto new_param =
std::accumulate(ops.begin(),
ops.end(),
m.add_parameter(name + ".0", shape{input.type(), lens}),
[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
m.replace_instruction(param, new_param);
m.remove_instruction(param);
}
}
code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
{
adjust_param_shapes(m, inputs);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace) if(trace)
std::cout << m << std::endl; std::cout << m << std::endl;
...@@ -662,13 +715,19 @@ instruction_ref insert_mlir(module& m, ...@@ -662,13 +715,19 @@ instruction_ref insert_mlir(module& m,
std::string dump_mlir(const module&) { return {}; } std::string dump_mlir(const module&) { return {}; }
code_object_op compile_mlir(const context&, const module&) { return {}; }
template <class T> template <class T>
void use(T&) void use(T&)
{ {
} }
// Disabling clang-tidy warning on non-real useage.
// NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
{
return {};
}
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref instruction_ref
// cppcheck-suppress funcArgNamesDifferent // cppcheck-suppress funcArgNamesDifferent
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&) insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
......
...@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const ...@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const
bool transa = inputs[0]->get_shape().transposed(); bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed(); bool transb = inputs[1]->get_shape().transposed();
if(!transb) if(not transb)
{ {
auto packed_b = m.insert_instruction( auto packed_b = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}})); ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
......
...@@ -108,16 +108,17 @@ auto query_miopen_db(const std::string& query) ...@@ -108,16 +108,17 @@ auto query_miopen_db(const std::string& query)
} // namespace } // namespace
std::string get_mlir_perf_for_conv(const problem_params& pp) std::string get_mlir_perf_for_conv(const problem_params& pp, bool xdlops)
{ {
std::string query = "select P.* \ std::string solver = xdlops ? "ConvMlirIgemmFwdXdlops" : "ConvMlirIgemmFwd";
std::string query = "select P.* \
from perf_db P, config C \ from perf_db P, config C \
where P.config = C.id AND \ where P.config = C.id AND \
P.solver = 'ConvMlirIgemmFwdXdlops' AND \ P.solver = '${solver}' AND \
${config}"; ${config}";
auto results = auto results = query_miopen_db(
query_miopen_db(interpolate_string(query, {{"config", generate_miopen_config(pp)}})); interpolate_string(query, {{"config", generate_miopen_config(pp)}, {"solver", solver}}));
if(results.empty()) if(results.empty())
return ""; return "";
return results.front().at("params"); return results.front().at("params");
......
...@@ -23,13 +23,62 @@ ...@@ -23,13 +23,62 @@
*/ */
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace { namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
float epsilon = 1e-12f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.epsilon, "epsilon"));
}
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
if(not mods.empty())
{
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
...@@ -38,60 +87,33 @@ struct find_layernorm ...@@ -38,60 +87,33 @@ struct find_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto eps = r.instructions["eps"]->eval().at<float>();
if(not x_ins->get_shape().standard()) m.replace_instruction(ins, layernorm{eps}, x_ins);
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
} }
}; };
struct find_triaddlayernorm struct find_add_layernorm
{ {
auto matcher() const auto matcher() const
{ {
auto add1 = return match::layernorm()(match::var("x")(match::name("add").bind("add")));
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["z1"]; auto add_ins = r.instructions["add"];
auto y_ins = r.instructions["z2"]; auto eps = r.instructions["eps"]->eval().at<float>();
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction( m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
} }
}; };
} // namespace } // namespace
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{}); match::find_matches(m, find_add_layernorm{}, find_layernorm{});
} }
} // namespace gpu } // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
}
argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format);
auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3];
}
shape miopen_quant_convolution::find(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x_shape = inputs[0];
auto w_shape = inputs[1];
if(int8_x4_format)
{
x_shape = pack_int8_shape(x_shape);
w_shape = pack_int8_shape(w_shape);
}
auto x = to_gpu(generate_argument(x_shape));
auto w = to_gpu(generate_argument(w_shape));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed");
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}};
}
void miopen_quant_convolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
if(cd == nullptr)
cd = make_conv(op);
if(solution_id == 0)
{
// Check that workspace hasn't changed
auto size = inputs.at(2).bytes();
auto ws = find(ctx, output_shape, inputs);
if(ws.bytes() > size)
MIGRAPHX_THROW("MIOpen Quant Convolution: workspace has changed during finalization.");
}
auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape);
auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_id);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: compile solution failed");
}
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,7 +21,13 @@ ...@@ -21,7 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <unordered_set>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) ...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return rb; return rb;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
}
bool get_int8_x4_format(context& ctx)
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.normalize_compute_shape({inputs.at(0)});
}
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
...@@ -109,13 +109,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -109,13 +109,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
insert_pad{}, insert_pad{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
inline_module{}, inline_module{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_gelu{},
dead_code_elimination{},
eliminate_common_subexpression{}, eliminate_common_subexpression{},
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
...@@ -134,16 +134,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -134,16 +134,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{}, pack_int8_args{},
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
compile_ops{&ctx}, compile_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
...@@ -26,15 +26,12 @@ ...@@ -26,15 +26,12 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp> #include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp> #include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp> #include <migraphx/op/lrn.hpp>
...@@ -75,84 +72,6 @@ typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::ena ...@@ -75,84 +72,6 @@ typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::ena
return x; return x;
} }
//
// ref implemenataion of batch norm for inference
//
// inputs are:
// args[0] -> input data buffer
// args[1] -> mini batch mean
// args[2] -> mini batch variance
// args[3] -> gamma
// args[4] -> bias
//
// The equation to compute batch norm for inference is:
//
// output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon)
//
// the input data format should be nchw
//
struct ref_batch_norm_inference
{
op::batch_norm_inference op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::batch_norm_inference"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument output{output_shape};
double epsilon = op.epsilon;
auto input = args[0];
auto arg_gamma = args[1];
auto arg_bias = args[2];
auto mini_batch_mean = args[3];
auto mini_batch_variance = args[4];
if(op.bn_mode == op::batch_norm_inference::spatial)
{
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
auto c = idx[1];
assert((variance[c] + epsilon) > 0);
result[i] =
gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) +
bias[c];
});
});
}
if(op.bn_mode == op::batch_norm_inference::per_activation)
{
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
idx[0] = 0;
auto index = output_shape.index(idx);
assert((variance[index] + epsilon) > 0);
result[i] = gamma[index] * (buffer[i] - mean[index]) /
std::sqrt(variance[index] + epsilon) +
bias[index];
});
});
}
return output;
}
};
MIGRAPHX_REGISTER_OP(ref_batch_norm_inference)
struct ref_lrn struct ref_lrn
{ {
op::lrn op; op::lrn op;
...@@ -237,16 +156,16 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -237,16 +156,16 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
std::vector<std::size_t> padding; std::vector<std::size_t> padding;
if(op.use_dynamic_same_auto_pad) if(op.padding_mode != op::padding_mode_t::default_)
{ {
auto input_lens = args[0].get_shape().lens(); auto input_lens = args[0].get_shape().lens();
std::vector<std::size_t> img_lens{input_lens.begin() + 2, input_lens.end()};
auto weights_lens = args[1].get_shape().lens(); auto weights_lens = args[1].get_shape().lens();
std::vector<std::size_t> k_lens{weights_lens.begin() + 2, weights_lens.end()}; padding =
padding = calc_dyn_auto_pad(img_lens, k_lens, op.stride, op.dilation); op.padding_mode == op::same_upper
std::cout << "[ "; ? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
output_shape = : calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
compute_padded_shape({args.at(0).get_shape(), args.at(1).get_shape()}, padding); output_shape = compute_padded_shape(
args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
} }
else else
{ {
...@@ -314,34 +233,6 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -314,34 +233,6 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
}); });
return result; return result;
} }
private:
/*!
* Used for dynamic auto padding since padding needs to be computed at evaulation time.
* \param inputs two fixed shape inputs [input_tensor, weights]
* \param padding from auto_pad calculation
*/
shape compute_padded_shape(const std::vector<shape>& inputs,
const std::vector<std::size_t>& padding) const
{
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
const size_t num_spatial_dims = input.lens().size() - 2;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for(size_t i = 0; i < num_spatial_dims; i++)
{
auto padding_factor = padding[i] + padding[i + num_spatial_dims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + op.dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
op.stride[i] +
1)));
}
return inputs[0].with_lens(output_lens);
}
}; };
struct ref_im2col struct ref_im2col
...@@ -538,65 +429,6 @@ struct ref_quant_gemm ...@@ -538,65 +429,6 @@ struct ref_quant_gemm
}; };
MIGRAPHX_REGISTER_OP(ref_gemm) MIGRAPHX_REGISTER_OP(ref_gemm)
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "ref::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
struct elu_op
{
op::elu op;
std::string name() const { return "ref::elu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op>
struct ref_unary : auto_register_op<ref_unary<Op>>
{
ref_unary() = default;
template <class T>
ref_unary(T pop) : op(Op{std::move(pop)})
{
}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op.op, f);
}
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
assert(input.get_shape().standard());
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
return result;
}
};
template <class Op> template <class Op>
struct ref_softmax : auto_register_op<ref_softmax<Op>> struct ref_softmax : auto_register_op<ref_softmax<Op>>
{ {
...@@ -732,16 +564,12 @@ struct ref_apply ...@@ -732,16 +564,12 @@ struct ref_apply
void init() void init()
{ {
apply_map["batch_norm_inference"] =
extend_op<ref_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>(); apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>(); apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>(); apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] = apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>(); extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>(); apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["leaky_relu"] = extend_op<ref_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>(); apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>(); apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
apply_map["pad"] = extend_op<ref_pad, op::pad>(); apply_map["pad"] = extend_op<ref_pad, op::pad>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment