Commit 4a39a0f7 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
......@@ -13,12 +13,14 @@ void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta);
float beta,
bool int8_x4_format);
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta);
int32_t beta,
bool int8_x4_format);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOOP_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/loop.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_loop
{
op::loop op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::loop"; }
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const;
argument
compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -94,7 +94,7 @@ inline convolution_descriptor make_conv(const T& op)
std::vector<int> stride(std::max(2, kdims), 1);
std::vector<int> dilation(std::max(2, kdims), 1);
std::copy_backward(op.padding.begin(), op.padding.end(), padding.end());
std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end());
std::copy_backward(op.stride.begin(), op.stride.end(), stride.end());
std::copy_backward(op.dilation.begin(), op.dilation.end(), dilation.end());
......@@ -145,7 +145,7 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
std::vector<int> stride(std::max(2, kdims), 1);
std::vector<int> lengths(std::max(2, kdims), 1);
std::copy_backward(op.padding.begin(), op.padding.end(), padding.end());
std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end());
std::copy_backward(op.stride.begin(), op.stride.end(), stride.end());
std::copy_backward(op.lengths.begin(), op.lengths.end(), lengths.end());
......
#ifndef MIGRAPHX_GUARD_RTGLIB_MULTINOMIAL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MULTINOMIAL_HPP
#include <migraphx/op/multinomial.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_multinomial
{
op::multinomial op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::multinomial"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_NONZERO_HPP
#define MIGRAPHX_GUARD_RTGLIB_NONZERO_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_nonzero
{
op::nonzero op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::nonzero"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -13,7 +13,7 @@ namespace gpu {
struct pack_int8_args
{
std::string name() const { return "gpu::pack_int8_args"; }
void apply(module& p) const;
void apply(module& m) const;
shape pack_int8_shape(const shape& s) const;
};
......
#ifndef MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
#define MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
#include <migraphx/gpu/name.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
{
op::prefix_scan_sum op;
template <class Self, class T>
static auto reflect(Self& self, T f)
{
return migraphx::reflect(self.op, f);
}
shape compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes{inputs};
in_shapes.pop_back();
check_shapes{in_shapes, *this}.standard();
return op.normalize_compute_shape(in_shapes);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
if(op.exclusive or op.reverse)
MIGRAPHX_THROW("Exclusive and reverse scan not supported");
device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
return args[1];
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reverse
{
op::reverse op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::reverse"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_scatter
{
op::scatter op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::scatter"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_TOPK_HPP
#define MIGRAPHX_GUARD_RTGLIB_TOPK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_topk
{
op::topk op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::topk"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/where.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_where : ternary_device<hip_where, device::where>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,10 +5,25 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape pack_int8_shape(const shape& s)
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
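For reference, a worked sketch of what this packing does to an int8 tensor (the dimensions below are illustrative, not taken from the diff):
// int8 shape with lens {2, 3, 5, 5} and strides {75, 25, 5, 1}
// lens[1]    -> (3 + 3) / 4 * 4 = 4
// strides[0] -> strides[1] * lens[1] = 25 * 4 = 100
// packed result: lens {2, 4, 5, 5}, strides {100, 25, 5, 1}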
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return inputs.at(0);
return pack_int8_shape(inputs.at(0));
}
argument
......
......@@ -43,6 +43,59 @@ constexpr bool is_sorted(Iterator first, Iterator last, Compare comp)
return is_sorted_until(first, last, comp) == last;
}
template <class Iterator, class F>
constexpr F for_each(Iterator first, Iterator last, F f)
{
for(; first != last; ++first)
{
f(*first);
}
return f;
}
template <class Iterator, class Predicate>
constexpr Iterator find_if(Iterator first, Iterator last, Predicate p)
{
for(; first != last; ++first)
{
if(p(*first))
{
return first;
}
}
return last;
}
template <class Iterator, class T>
constexpr Iterator find(Iterator first, Iterator last, const T& value)
{
return find_if(first, last, [&](const auto& x) { return x == value; });
}
template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{
for(;; ++first)
{
Iterator1 it = first;
for(Iterator2 s_it = s_first;; ++it, ++s_it)
{
if(s_it == s_last)
{
return first;
}
if(it == last)
{
return last;
}
if(!(*it == *s_it))
{
break;
}
}
}
}
} // namespace migraphx
#endif
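These constexpr algorithms mirror their std:: counterparts so they can be used inside device code; a minimal host-side sketch of how find and search behave (the arrays here are illustrative only):
constexpr int data[] = {1, 2, 3, 4};
constexpr int sub[]  = {3, 4};
static_assert(migraphx::find(data, data + 4, 3) == data + 2, "find returns the first match");
static_assert(migraphx::search(data, data + 4, sub, sub + 2) == data + 2,
              "search returns the start of the first occurrence");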
......@@ -2,57 +2,26 @@
#define MIGRAPHX_GUARD_KERNELS_ARGS_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx {
template <std::size_t N>
struct arg
{
};
template <std::size_t...>
struct seq
{
using type = seq;
};
template <class, class>
struct merge_seq;
template <std::size_t... Xs, std::size_t... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
template <std::size_t N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
// Use template specialization since ADL is broken on hcc
template <std::size_t>
template <index_int>
struct make_tensor;
template <class F, std::size_t... Ns, class... Ts>
__device__ auto make_tensors_impl(F f, seq<Ns...>, Ts*... xs)
template <class F, index_int... Ns, class... Ts>
__device__ auto make_tensors_impl(F f, detail::seq<Ns...>, Ts*... xs)
{
f(make_tensor<Ns>::apply(xs)...);
return f(make_tensor<Ns>::apply(xs)...);
}
template <class... Ts>
__device__ auto make_tensors(Ts*... xs)
inline __device__ auto make_tensors()
{
return [=](auto f) { make_tensors_impl(f, gens<sizeof...(Ts)>{}, xs...); };
return [](auto*... xs) {
return [=](auto f) { return make_tensors_impl(f, detail::gens<sizeof...(xs)>{}, xs...); };
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP
\ No newline at end of file
#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP
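The reworked make_tensors is curried: it is called with no arguments, then takes the raw kernel pointers, and finally passes the corresponding tensor views to a callback whose result is returned. A sketch of a kernel body consuming it (pointer names are illustrative; make_index comes from kernels/index.hpp and the make_tensor<N> specializations are assumed to be generated per kernel as usual):
__device__ void copy_kernel(float* in, float* out)
{
    make_tensors()(in, out)([](auto t_in, auto t_out) {
        auto idx = make_index();
        idx.global_stride(t_out.get_shape().elements(), [&](auto i) {
            // same indexing pattern as pointwise_tensor in pointwise.hpp
            t_out[t_out.get_shape().multi(i)] = t_in[t_in.get_shape().multi(i)];
        });
    });
}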
......@@ -2,7 +2,8 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#include <migraphx/kernels/types.hpp>
#include <type_traits>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
......@@ -41,8 +42,16 @@ template <class T, index_int N>
struct array
{
T d[N];
constexpr T& operator[](index_int i) { return d[i]; }
constexpr const T& operator[](index_int i) const { return d[i]; }
constexpr T& operator[](index_int i)
{
MIGRAPHX_ASSERT(i < N);
return d[i];
}
constexpr const T& operator[](index_int i) const
{
MIGRAPHX_ASSERT(i < N);
return d[i];
}
constexpr T& front() { return d[0]; }
constexpr const T& front() const { return d[0]; }
......@@ -53,7 +62,7 @@ struct array
constexpr T* data() { return d; }
constexpr const T* data() const { return d; }
constexpr std::integral_constant<index_int, N> size() const { return {}; }
constexpr index_constant<N> size() const { return {}; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
......@@ -142,6 +151,18 @@ struct array
result[0] += overflow;
return result;
}
template <class Stream>
friend constexpr const Stream& operator<<(const Stream& ss, const array& a)
{
for(index_int i = 0; i < N; i++)
{
if(i > 0)
ss << ", ";
ss << a[i];
}
return ss;
}
};
template <class T, T... xs>
......@@ -151,6 +172,18 @@ struct integral_const_array : array<T, sizeof...(xs)>
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
};
template <class T, T... xs, class F>
constexpr auto transform(integral_const_array<T, xs...>, F f)
{
return integral_const_array<T, f(xs)...>{};
}
template <class T, T... xs, class U, U... ys, class F>
constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f)
{
return integral_const_array<T, f(xs, ys)...>{};
}
template <index_int... Ns>
using index_ints = integral_const_array<index_int, Ns...>;
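The new transform overloads map a stateless callable over the compile-time values and yield another integral_const_array; a sketch with illustrative values:
// index_ints<2, 3, 4> transformed element-wise into index_ints<4, 6, 8>
constexpr auto doubled = transform(index_ints<2, 3, 4>{}, [](auto x) { return x * 2; });
// pairwise: index_ints<1, 2> and index_ints<10, 20> combine into index_ints<11, 22>
constexpr auto sums = transform(index_ints<1, 2>{}, index_ints<10, 20>{}, [](auto x, auto y) { return x + y; });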
......
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <hip/hip_runtime.h>
namespace migraphx {
inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
{
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort();
}
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \
assert_fail(xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
#else
#define MIGRAPHX_ASSERT(cond)
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
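MIGRAPHX_ASSERT expands to nothing unless MIGRAPHX_DEBUG is defined, so checks like the ones added to array::operator[] cost nothing in release builds. A sketch of a guarded device helper (the function and bound are illustrative):
__device__ float load_checked(const float* p, index_int i, index_int n)
{
    MIGRAPHX_ASSERT(i < n); // prints file, line, and function, then aborts, when MIGRAPHX_DEBUG is set
    return p[i];
}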
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/array.hpp>
namespace migraphx {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <index_int>
using ignore = swallow;
namespace detail {
template <class R>
struct eval_helper
{
R result;
template <class F, class... Ts>
constexpr eval_helper(const F& f, Ts&&... xs) : result(f(static_cast<Ts>(xs)...))
{
}
};
template <>
struct eval_helper<void>
{
int result;
template <class F, class... Ts>
constexpr eval_helper(const F& f, Ts&&... xs) : result((f(static_cast<Ts>(xs)...), 0))
{
}
};
template <index_int...>
struct seq
{
using type = seq;
};
template <class, class>
struct merge_seq;
template <index_int... Xs, index_int... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
template <index_int N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
template <class F, index_int... Ns>
constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
{
return f(index_constant<Ns>{}...);
}
template <index_int... N>
constexpr auto args_at(seq<N...>)
{
return [](ignore<N>..., auto x, auto...) { return x; };
}
} // namespace detail
template <class T>
constexpr auto always(T x)
{
return [=](auto&&...) { return x; };
}
template <index_int N, class F>
constexpr auto sequence_c(F&& f)
{
return detail::sequence_c_impl(f, detail::gens<N>{});
}
template <class IntegerConstant, class F>
constexpr auto sequence(IntegerConstant ic, F&& f)
{
return sequence_c<ic>(f);
}
template <class F, class G>
constexpr auto by(F f, G g)
{
return [=](auto... xs) {
return detail::eval_helper<decltype(g(f(xs)...))>{g, f(xs)...}.result;
};
}
template <class F>
constexpr auto by(F f)
{
return by([=](auto x) { return (f(x), 0); }, always(0));
}
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
}
template <class F>
constexpr void each_args(F)
{
}
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
template <index_int N>
constexpr auto arg_c()
{
return [](auto... xs) { return detail::args_at(detail::gens<N>{})(xs...); };
}
template <class IntegralConstant>
constexpr auto arg(IntegralConstant ic)
{
return arg_c<ic>();
}
inline constexpr auto rotate_last()
{
return [](auto... xs) {
return [=](auto&& f) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
};
};
}
template <class F>
constexpr auto transform_args(F f)
{
return [=](auto... xs) {
return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
};
}
template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs)
{
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
}
#define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
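rotate_last is what later lets pointwise take the output pointer last while pointwise_tensor takes the output first: it rotates the final argument to the front before invoking the wrapped callable. A sketch with plain integers (values illustrative):
// rotate_last()(1, 2, 3) forwards (3, 1, 2) to the callback
auto r = rotate_last()(1, 2, 3)([](auto a, auto b, auto c) { return a * 100 + b * 10 + c; });
// r == 312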
......@@ -12,9 +12,43 @@ struct index
index_int local = 0;
index_int group = 0;
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nglobal() const
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
return blockDim.x * gridDim.x;
#endif
}
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT
__device__ index_int nlocal() const
{
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
#else
return blockDim.x;
#endif
}
template <class F>
__device__ void global_stride(index_int n, F f) const
{
const auto stride = nglobal();
for(index_int i = global; i < n; i += stride)
{
f(i);
}
}
template <class F>
__device__ void local_stride(index_int n, F f) const
{
const auto stride = nlocal();
for(index_int i = local; i < n; i += stride)
{
f(i);
}
}
};
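global_stride is a grid-stride loop and local_stride a block-stride loop; defining MIGRAPHX_NGLOBAL or MIGRAPHX_NLOCAL lets the compiler treat the stride as a compile-time constant. A sketch of a kernel body using the grid-stride form (names illustrative; make_index is declared just below):
__device__ void fill_ones(float* out, index_int n)
{
    auto idx = make_index();
    idx.global_stride(n, [&](auto i) { out[i] = 1.0f; });
}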
inline __device__ index make_index()
......
#ifndef MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
using value_type = T;
using type = integral_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
};
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T v, class U, U w> \
constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \
integral_constant<T, v>, integral_constant<U, w>) noexcept \
{ \
return {}; \
}
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T v> \
constexpr inline integral_constant<decltype(op v), (op v)> operator op( \
integral_constant<T, v>) noexcept \
{ \
return {}; \
}
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(-)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(*)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(/)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(%)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>>)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<<)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
template <index_int N>
using index_constant = integral_constant<index_int, N>;
template <auto v>
static constexpr auto _c = integral_constant<decltype(v), v>{};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
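The operator macros keep arithmetic on integral_constants in the type system, and _c<v> gives a terse way to spell them; a sketch:
constexpr auto three = _c<1> + _c<2>; // integral_constant<int, 3>
static_assert(decltype(three)::value == 3, "computed at compile time");
static_assert(_c<7> % _c<4> == _c<3>, "the result is itself an integral_constant");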
#ifndef MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp>
namespace migraphx {
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = f(ps[multi_idx]...);
});
});
}
template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps)
{
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
t(ps...)([&](auto... xs) {
auto idx = make_index();
pointwise_tensor(idx, f, xs...);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
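Putting the pieces together, pointwise takes the raw pointers with the output last (rotate_last moves it to the front), preloads the inputs, and applies f element-wise over a grid-stride loop. A sketch of a generated elementwise-add kernel (the kernel name and pointer order are illustrative):
extern "C" __global__ void add_kernel(float* x, float* y, float* out)
{
    // f receives one value per input tensor; the result is written to out
    migraphx::pointwise([](auto a, auto b) { return a + b; }, x, y, out);
}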