Commit 3eaa0969 authored by Alan Turner

Merge remote-tracking branch 'origin/bert-opt' into HEAD

parents e3d0c287 22aa9c5e
@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
-struct module;
+struct module_pass_manager;
 
 namespace gpu {
 
 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& m) const;
 };
 
 } // namespace gpu
...
@@ -118,16 +118,14 @@ struct reduce_compiler : compiler<reduce_compiler>
         options.virtual_inputs = reduce_dims(inputs);
         auto faxis = find_fast_axis({options.virtual_inputs.front()});
         vectorize vec{};
-        // Vectorize if the axis is a reduction axis
-        if(options.virtual_inputs.back().lens()[faxis] == 1)
-        {
-            vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
-        }
-        auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
         auto nelements = options.virtual_inputs.back().elements();
         auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
         if(algo == "block")
         {
+            // Vectorize if the axis is a reduction axis
+            if(options.virtual_inputs.back().lens()[faxis] == 1)
+                vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
+            auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
             auto block_size = compute_block_size(relements, 256);
             options.set_launch_params(
                 v, compute_global_for(ctx, nelements * block_size, 256), block_size);
@@ -166,7 +164,7 @@ struct reduce_compiler : compiler<reduce_compiler>
         auto reduce_elements = get_reduce_elements(ins->inputs());
         auto reduce_type = ins->inputs().front()->get_shape().type();
         v["reduction"] = "op::sum{}";
-        std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
+        std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
         // Use float accumulator when reduction size is too large for half
         if(reduce_type == shape::half_type and reduce_elements > 16384)
             v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
...
@@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
     }
     else
     {
-        using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
+        using vec_type = remove_reference_t<decltype(array2vec(x))>;
         f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
     }
 }
...
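For reference, __builtin_convertvector is a Clang vector extension that converts a vector value lane by lane to another vector type of the same width; the vec_type alias above just names the destination type. A standalone sketch of the semantics (the type aliases here are illustrative, not MIGraphX code):

    // Clang only: ext_vector_type is a Clang vector extension.
    using int4   = int __attribute__((ext_vector_type(4)));
    using float4 = float __attribute__((ext_vector_type(4)));

    float4 to_float4(int4 x)
    {
        // Converts each of the four lanes, like a per-lane static_cast.
        return __builtin_convertvector(x, float4);
    }

The switch from std::remove_reference_t to the kernels' own remove_reference_t presumably keeps the header independent of <type_traits>, which is not available when compiling under HIPRTC.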
@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
 #define MIGRAPHX_WARN(...)
 #endif
 
+#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
+    static_assert(__VA_ARGS__);         \
+    if constexpr(__VA_ARGS__)
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
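MIGRAPHX_STATIC_ASSERT_FOR expands to a static_assert on the condition followed by an if constexpr header, so the compound statement placed after the macro is both checked at compile time and only instantiated when the condition holds. A usage sketch under that reading (the bound here is hypothetical):

    template <unsigned Iterations>
    constexpr void run_unrolled()
    {
        MIGRAPHX_STATIC_ASSERT_FOR(Iterations < 256)
        {
            // Expands to:
            //   static_assert(Iterations < 256);
            //   if constexpr(Iterations < 256) { ... }
            // so this body is never instantiated with an out-of-range count.
        }
    }

The if constexpr guard looks redundant next to the assert, but it would let an alternate definition of the macro drop the static_assert while still compiling the block out; see the for_stride change below for the motivating call site.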
@@ -72,7 +72,7 @@ __device__ T dpp_mov(T& x)
     }
     return output.data;
 }
-#endif
+#endif // MIGRAPHX_HAS_DPP
 
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
@@ -26,7 +26,7 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/algorithm.hpp>
+#include <migraphx/kernels/ops.hpp>
 
 namespace migraphx {
 
 template <class T>
@@ -53,22 +53,16 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
     auto indices_shape_lens = indices_shape.lens;
     auto data_shape_lens = data_shape.lens;
     auto num_slice_dims = indices_shape_lens.back();
-    std::size_t num_slices = accumulate(indices_shape_lens.begin(),
-                                        indices_shape_lens.end() - 1,
-                                        1,
-                                        std::multiplies<std::size_t>());
+    std::size_t num_slices =
+        accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
     std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                         data_shape_lens.end(),
                                         1,
-                                        std::multiplies<std::size_t>());
-    const std::size_t num_batches = accumulate(data_shape_lens.begin(),
-                                               data_shape_lens.begin() + batch_dims,
-                                               1,
-                                               std::multiplies<std::size_t>());
-    const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
-                                                     data_shape_lens.end(),
-                                                     1,
-                                                     std::multiplies<std::size_t>());
+                                        op::product{});
+    const std::size_t num_batches =
+        accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
+    const std::size_t data_batch_stride =
+        accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
     const auto num_slices_per_batch = num_slices / num_batches;
 
     ind.global_stride(output_shape.elements(), [&](auto i) {
@@ -83,7 +77,7 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
             int64_t index = slice_indices[idx];
             const std::size_t input_dim_idx = batch_dims + idx;
             const auto input_dim = data_shape_lens[input_dim_idx];
-            assert(index >= -static_cast<int64_t>(input_dim) and
-                   index < static_cast<int64_t>(input_dim));
+            MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
+                            index < static_cast<int64_t>(input_dim));
             if(index < 0)
                 index += input_dim;
@@ -91,7 +85,7 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
                 accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                            data_shape_lens.begin() + batch_dims + num_slice_dims,
                            slice_size,
-                           std::multiplies<std::size_t>());
+                           op::product{});
             relative_slice_offset += index * size_from_slice_dims;
         }
...
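op::product takes over for std::multiplies<std::size_t>, presumably because <functional> is not available when these kernels are compiled under HIPRTC. Its real definition lives in the newly included migraphx/kernels/ops.hpp; a minimal sketch of the expected shape is just a generic binary functor:

    namespace op {
    struct product
    {
        template <class T, class U>
        constexpr auto operator()(T x, U y) const
        {
            return x * y; // same role as std::multiplies, no <functional> needed
        }
    };
    } // namespace op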
@@ -24,11 +24,18 @@
 #ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
 #define MIGRAPHX_GUARD_KERNELS_HIP_HPP
 
+#ifndef MIGRAPHX_USE_HIPRTC
 // Workaround macro redefinition issue with clang tidy
 #if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
 #undef __HIP_PLATFORM_HCC__ // NOLINT
 #endif
 #include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/math_functions.h>
+#include <hip/hip_math_constants.h>
+#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
+#include <hip/hip_common.h>
+#include <hip/hip_math_constants.h>
+#endif
 
 #endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
@@ -163,6 +163,30 @@ struct index
     }
 
     template <class F, class N, class Stride>
+    static constexpr void for_stride_loop_unroll(index_int start, N n, Stride stride, F f)
+    {
+        sequence(max_stride_iterations(n, stride), [&](auto... ks) {
+            fold([&](auto d, auto k) {
+                auto i = start + stride * k;
+                if(i < n)
+                    invoke_loop(f, i, d);
+                return d + _c<1>;
+            })(_c<0>, ks...);
+        });
+    }
+
+    template <class F, class N, class Stride>
+    static constexpr void for_stride_loop(index_int start, N n, Stride stride, F f)
+    {
+        index_int k = 0;
+        for(index_int i = start; i < n; i += stride)
+        {
+            invoke_loop(f, i, k);
+            k++;
+        }
+    }
+
+    template <bool Unroll, class F, class N, class Stride>
     static constexpr void for_stride(index_int start, N n, Stride stride, F f)
     {
         MIGRAPHX_ASSERT(start < stride);
@@ -180,46 +204,40 @@ struct index
                     invoke_loop(f, start, _c<0>);
                 }
             }
-            else
+            else if constexpr(Unroll)
             {
-                static_assert(max_stride_iterations(n, stride) < 128);
-                sequence(max_stride_iterations(n, stride), [&](auto... ks) {
-                    fold([&](auto d, auto k) {
-                        auto i = start + stride * k;
-                        if(i < n)
-                            invoke_loop(f, i, d);
-                        return d + _c<1>;
-                    })(_c<0>, ks...);
-                });
+                MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
+                {
+                    for_stride_loop_unroll(start, n, stride, f);
+                }
+            }
+            else
+            {
+                for_stride_loop(start, n, stride, f);
             }
         }
         else
         {
-            index_int k = 0;
-            for(index_int i = start; i < n; i += stride)
-            {
-                invoke_loop(f, i, k);
-                k++;
-            }
+            for_stride_loop(start, n, stride, f);
        }
    }
 
     template <class F, class N>
     __device__ void global_stride(N n, F f) const
     {
-        for_stride(global, n, nglobal(), f);
+        for_stride<false>(global, n, nglobal(), f);
     }
 
     template <class F, class N>
     __device__ void local_stride(N n, F f) const
     {
-        for_stride(local, n, nlocal(), f);
+        for_stride<true>(local, n, nlocal(), f);
     }
 
     template <class F, class N>
     __device__ void group_stride(N n, F f) const
     {
-        for_stride(group, n, ngroup(), f);
+        for_stride<false>(group, n, ngroup(), f);
     }
 };
...
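The sequence/fold pair in for_stride_loop_unroll expands the strided loop at compile time: each iteration index k becomes a distinct compile-time constant, and every step keeps its runtime bounds check. A standalone C++17 sketch of the same unrolling idea (simplified names, not the MIGraphX helpers):

    #include <cstddef>
    #include <utility>

    template <std::size_t... Ks, class F>
    constexpr void unrolled_stride_impl(
        std::index_sequence<Ks...>, unsigned start, unsigned n, unsigned stride, F f)
    {
        // One guarded invocation per compile-time index; the comma fold unrolls fully.
        (..., [&] {
            unsigned i = start + stride * Ks;
            if(i < n)
                f(i, Ks);
        }());
    }

    template <std::size_t Iterations, class F>
    constexpr void unrolled_stride(unsigned start, unsigned n, unsigned stride, F f)
    {
        unrolled_stride_impl(std::make_index_sequence<Iterations>{}, start, n, stride, f);
    }

Note which callers opt in: only local_stride passes Unroll = true, where the per-lane trip count is a small compile-time constant; global_stride and group_stride keep the plain runtime loop.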
@@ -28,8 +28,7 @@
 #include <migraphx/kernels/vec.hpp>
 #include <migraphx/kernels/functional.hpp>
 #include <migraphx/kernels/type_traits.hpp>
-#include <hip/hip_fp16.h>
-#include <hip/math_functions.h>
+#include <migraphx/kernels/hip.hpp>
 
 namespace migraphx {
@@ -222,7 +221,7 @@ constexpr auto min(const T& a, const U& b)
 template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
 constexpr T sin(T x)
 {
-    constexpr const T shift = M_PI_2;
+    constexpr const T shift = HIP_PIO2_F;
     return migraphx::cos(shift - x);
 }
...
@@ -66,13 +66,22 @@ struct convert_to
     }
 };
 
+template <index_int N>
 struct mean
 {
-    index_int item_num = 1;
     template <class T>
-    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
+    MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x) const
     {
-        return x / static_cast<T>(item_num);
+        using type = vec_type<T>;
+        if constexpr(is_floating_point<type>{})
+        {
+            constexpr type d = 1.0 / N;
+            return x * d;
+        }
+        else
+        {
+            return x / static_cast<type>(N);
+        }
     }
 };
...
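Turning the divisor into the template parameter N is what makes the floating-point branch cheap: 1.0 / N folds to a constant, so each element costs one multiply instead of one divide, while integer types keep exact truncating division. A sketch of the effect with an assumed reduction size (the real N is substituted via the "op::mean<" ... ">{}" string in the reduce compiler above):

    constexpr float mean_from_sum(float partial_sum)
    {
        constexpr unsigned n = 4096;    // hypothetical reduction size
        constexpr float inv = 1.0f / n; // folded at compile time
        return partial_sum * inv;       // multiply, no per-element divide
    }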
@@ -391,22 +391,40 @@ struct block
 struct lane
 {
     template <class Slicer>
-    struct reducer
+    struct reducer : reducer_base<reducer<Slicer>>
     {
         index idx;
         Slicer slice;
-        template <class Op, class T, class Read>
-        __device__ auto reduce(Op op, T init, Read read) const
+
+        template <class Size, class F>
+        struct inner_storage : inner_storage_tag
+        {
+            using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+            F f;
+            constexpr Size rsize() const { return {}; }
+            template <class U, class V>
+            constexpr auto operator()(U j, V d) const
+            {
+                return f(j, d);
+            }
+        };
+
+        template <class Size, class F>
+        constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
+        {
+            return {f};
+        }
+
+        template <class Op, class T, class Read, class N, class U, class... Us>
+        __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
         {
-            return sliced(slice, [=](auto x, auto... xs) {
-                using type = typename decltype(x)::type;
-                type r = init;
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    r = op(r, read(x[j], xs[j]...));
-                }
-                return r;
-            });
+            using type = remove_reference_t<decltype(x(0, _c<0>))>;
+            type r = init;
+            for(index_int j = 0; j < n; j++)
+            {
+                r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
+            }
+            return r;
         }
 
         template <class F>
@@ -415,29 +433,25 @@ struct lane
             f();
         }
 
-        template <class F>
-        __device__ auto inner(F f) const
+        template <class F, class N, class... Ts>
+        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
         {
-            return sliced(slice, [=](auto x, auto... xs) {
-                for(index_int j = 0; j < x.get_shape().elements(); j++)
-                {
-                    f(x[j], xs[j]...);
-                }
-            });
+            for(index_int j = 0; j < n; j++)
+            {
+                f(xs(j, _c<0>)...);
+            }
         }
 
-        template <class Input>
-        constexpr auto elements() const
+        template <class R, class F, class N, class... Ts>
+        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
         {
-            using reduce_type = decltype(slice(Input{}));
-            return get_shape_c<reduce_type>{}.elements();
+            return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
         }
     };
 
     template <class Slicer>
     static __device__ auto make(index idx, Slicer slicer)
     {
-        return reducer<Slicer>{idx, slicer};
+        return reducer<Slicer>{{}, idx, slicer};
     }
 
     template <class Output, class F>
...
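The shape of this refactor: inner previously materialized a sliced tensor and indexed it with x[j]; inner_impl now returns an inner_storage, a lazy callable that computes element (j, d) on demand and carries its element count as the compile-time Size. A reduced sketch of the pattern (simplified names and interfaces):

    template <class Size, class F>
    struct lazy_inner
    {
        F f;
        constexpr Size size() const { return {}; } // count known at compile time
        template <class J, class D>
        constexpr auto operator()(J j, D d) const
        {
            return f(j, d); // element generated on demand, nothing stored
        }
    };

reduce_impl then walks x(j, _c<0>) over n elements without ever storing the intermediate slice.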
@@ -76,14 +76,6 @@ struct shape
     constexpr index_int index(index_array x) const { return x.dot(strides); }
 
-    constexpr index_int index(std::initializer_list<index_int> x) const
-    {
-        index_int idx = 0;
-        for(index_int i = 0; i < x.size(); i++)
-            idx += *(x.begin() + i) * strides[i];
-        return idx;
-    }
-
     constexpr index_int index(index_int i) const
     {
         if(this->standard())
...
@@ -28,8 +28,45 @@
 namespace migraphx {
 
-using index_int = std::uint32_t;
-using diff_int = std::int32_t;
+#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
+using int8_t = signed char;
+using uint8_t = unsigned char;
+using int16_t = signed short;
+using uint16_t = unsigned short;
+using int32_t = signed int;
+using uint32_t = unsigned int;
+using int64_t = signed long long;
+using uint64_t = unsigned long long;
+#elif defined(MIGRAPHX_USE_HIPRTC)
+using int8_t = __hip_int8_t;
+using uint8_t = __hip_uint8_t;
+using int16_t = __hip_int16_t;
+using uint16_t = __hip_uint16_t;
+using int32_t = __hip_int32_t;
+using uint32_t = __hip_uint32_t;
+using int64_t = __hip_int64_t;
+using uint64_t = __hip_uint64_t;
+#else
+using int8_t = std::int8_t;
+using uint8_t = std::uint8_t;
+using int16_t = std::int16_t;
+using uint16_t = std::uint16_t;
+using int32_t = std::int32_t;
+using uint32_t = std::uint32_t;
+using int64_t = std::int64_t;
+using uint64_t = std::uint64_t;
+#endif // MIGRAPHX_USE_HIPRTC
+
+using index_int = uint32_t;
+using diff_int = int32_t;
+
+static_assert(sizeof(int8_t) == 1, "int8_t must be 1 bytes");
+static_assert(sizeof(uint8_t) == 1, "uint8_t must be 1 bytes");
+static_assert(sizeof(int16_t) == 2, "int16_t must be 2 bytes");
+static_assert(sizeof(uint16_t) == 2, "uint16_t must be 2 bytes");
+static_assert(sizeof(int32_t) == 4, "int32_t must be 4 bytes");
+static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");
+static_assert(sizeof(int64_t) == 8, "int64_t must be 8 bytes");
+static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes");
 
 #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
...
@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps = r.instructions["eps"]->eval().at<float>();
+        float eps = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();
 
         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
@@ -100,25 +104,26 @@ struct find_add_layernorm
 {
     auto matcher() const
     {
-        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }
 
     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto add_ins = r.instructions["add"];
-        float eps = 0;
-        if(contains(r.instructions, "eps"))
-            eps = r.instructions["eps"]->eval().at<float>();
+        auto op = any_cast<layernorm>(ins->get_operator());
 
-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };
 
 } // namespace
 
-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }
 
 } // namespace gpu
...
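A plausible reading of why prefuse_ops now takes the module_pass_manager: the two matchers have to run as separate stages with cleanup in between. find_add_layernorm matches the gpu::prelayernorm instructions that find_layernorm just created, and its match::used_once() constraint on the feeding add is only accurate once dead code from the first rewrite is removed:

    // Staged rewrite, as in apply() above (annotations are interpretation):
    // 1. layernorm subgraph over x  ->  gpu::prelayernorm(x)    (find_layernorm)
    // 2. dead_code_elimination removes the now-unused subgraph,
    //    so the use counts seen by used_once() are accurate
    // 3. gpu::prelayernorm(add(a, b))  ->  add_layernorm(a, b)  (find_add_layernorm)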
@@ -38,7 +38,7 @@
 #include <migraphx/layout_nhwc.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
-#include <migraphx/optimize.hpp>
+#include <migraphx/optimize_module.hpp>
 #include <migraphx/preallocate_param.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
@@ -121,18 +121,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         rewrite_pooling{},
         dead_code_elimination{},
         rewrite_gelu{},
-        optimize{},
+        optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
-        simplify_reshapes{},
-        propagate_constant{},
-        dead_code_elimination{},
-        fuse_ck_gemm_softmax_gemm{&ctx},
-        dead_code_elimination{},
-        optimize{},
+        optimize_module{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         fuse_mlir{&ctx},
...
@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs)
     };
 }
 
-template <class Op>
-struct ref_convolution : auto_register_op<ref_convolution<Op>>
-{
-    ref_convolution() = default;
-    ref_convolution(Op pop) : op(std::move(pop)) {}
-
-    Op op;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-
-    std::string name() const { return "ref::" + op.name(); }
-
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        return op.normalize_compute_shape(inputs);
-    }
-
-    argument compute(context&, shape output_shape, std::vector<argument> args) const
-    {
-        std::vector<std::size_t> padding;
-        if(op.padding_mode != op::padding_mode_t::default_)
-        {
-            auto input_lens = args[0].get_shape().lens();
-            auto weights_lens = args[1].get_shape().lens();
-            padding =
-                op.padding_mode == op::same_upper
-                    ? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
-                    : calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
-            output_shape = compute_padded_shape(
-                args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
-        }
-        else
-        {
-            padding = op.padding;
-            if(output_shape.dynamic())
-            {
-                output_shape =
-                    op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
-            }
-        }
-
-        argument result{output_shape};
-        visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
-            auto in_lens = input.get_shape().lens();
-            auto wei_lens = weights.get_shape().lens();
-            auto wei_n = wei_lens[0];
-            auto wei_c = wei_lens[1];
-            std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
-            par_for(output_shape.elements(), [&](auto i) {
-                auto idx_o = output_shape.multi(i);
-                auto w = idx_o[1];
-                auto n_dim = idx_o.size();
-                std::vector<std::ptrdiff_t> win_start;
-                for(std::size_t dim = 2; dim < n_dim; ++dim)
-                {
-                    auto d_2 = dim - 2;
-                    win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
-                                        std::ptrdiff_t(padding[d_2]));
-                }
-                const auto group_id = w / (wei_n / op.group);
-                shape win_shape{output_shape.type(), win_size};
-
-                double acc = 0.0;
-                shape_for_each(win_shape, [&](auto idx_win) {
-                    auto k = idx_win[0];
-                    const auto in_ch = group_id * wei_c + k;
-                    std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
-                    idx[1] = in_ch;
-                    std::transform(idx_win.begin() + 1,
-                                   idx_win.end(),
-                                   win_start.begin(),
-                                   idx.begin() + 2,
-                                   [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
-                    std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
-                    idx_wei[0] = w;
-                    std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
-                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
-                       std::equal(idx.begin(),
-                                  idx.end(),
-                                  in_lens.begin(),
-                                  in_lens.end(),
-                                  std::less<std::ptrdiff_t>{}))
-                    {
-                        acc +=
-                            input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
-                    }
-                });
-
-                output[i] = acc;
-            });
-        });
-
-        return result;
-    }
-};
-
 struct ref_im2col
 {
     op::im2col op;
@@ -564,11 +461,8 @@ struct ref_apply
     void init()
     {
-        apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
         apply_map["dot"] = extend_op<ref_gemm, op::dot>();
         apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
-        apply_map["quant_convolution"] =
-            extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
        apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
        apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
        apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
...
-#####################################################################################
+# ####################################################################################
 # The MIT License (MIT)
 #
 # Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
@@ -20,7 +20,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
-#####################################################################################
+# ####################################################################################
 
 cmake_policy(SET CMP0057 NEW)
@@ -49,9 +49,11 @@ function(add_test_command NAME EXE)
         set_tests_properties(${NAME} PROPERTIES DISABLED On)
     elseif(WIN32)
         set(WINPATH)
+
         foreach(PATH ${CMAKE_FIND_ROOT_PATH})
             list(APPEND WINPATH ${PATH}/bin)
         endforeach()
+
         file(GENERATE OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmd"
             CONTENT "set PATH=${WINPATH};%PATH%
 %1 ${ARGN}")
@@ -67,9 +69,11 @@ function(add_test_command NAME EXE)
         #     --args $<TARGET_FILE:${EXE}> ${ARGN})
         set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
         file(MAKE_DIRECTORY ${TEST_DIR})
-        if (NOT EXISTS ${TEST_DIR})
+
+        if(NOT EXISTS ${TEST_DIR})
             message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}")
         endif()
+
         file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
             CONTENT "
 # Remove previous core dump
@@ -90,22 +94,27 @@ function(add_test_command NAME EXE)
             add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN})
         endif()
     endif()
+
     set_tests_properties(${NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED")
 endfunction()
 
 function(add_test_executable TEST_NAME)
-    add_executable (${TEST_NAME} EXCLUDE_FROM_ALL ${ARGN})
+    add_executable(${TEST_NAME} EXCLUDE_FROM_ALL ${ARGN})
     target_link_libraries(${TEST_NAME} ${CMAKE_THREAD_LIBS_INIT})
+
     # Cmake does not add flags correctly for gcc
     if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
         set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
     endif()
+
     separate_arguments(MIOPEN_TEST_FLAGS_ARGS UNIX_COMMAND ${MIOPEN_TEST_FLAGS})
+
     if(MIOPEN_TEST_ALL)
         set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} --all ${MIOPEN_TEST_FLAGS_ARGS})
     else()
         set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} ${MIOPEN_TEST_FLAGS_ARGS})
     endif()
+
     add_test_command(${TEST_NAME} ${TEST_COMMAND})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
@@ -133,7 +142,7 @@ if(MIGRAPHX_ENABLE_GPU)
             COST 10
             RESOURCE_LOCK gpu
         )
-        target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu)
+        target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu migraphx_kernels)
     endforeach()
 endif()
@@ -155,7 +164,8 @@ endif()
 # Onnx test
 set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
-file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
+file(GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
+
 foreach(ONNX_TEST ${ONNX_TESTS})
     get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
     set(TEST_NAME test_${BASE_NAME})
@@ -180,12 +190,12 @@ add_dependencies(check test_tf)
 add_subdirectory(api)
 add_subdirectory(verify)
 
 if(MIGRAPHX_ENABLE_PYTHON)
     add_subdirectory(py)
 endif()
 
 function(test_header NAME HEADER)
     file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
         "#include <${HEADER}>\nint main() {}\n"
     )
@@ -206,6 +216,7 @@ function(test_headers PREFIX)
         string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
         get_filename_component(BASE_NAME ${HEADER} NAME_WE)
         test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
+
         if(MIGRAPHX_ENABLE_GPU)
             target_link_libraries(header_${TEST_NAME} migraphx_gpu)
         endif()
@@ -214,6 +225,7 @@ endfunction()
 test_headers(migraphx ${CMAKE_SOURCE_DIR}/src/include/migraphx/*.hpp)
 test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/ref/*.hpp)
+
 if(MIGRAPHX_ENABLE_GPU)
     test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp)
 endif()
...
@@ -30,7 +30,6 @@ TEST_CASE(load_save_default)
     std::string filename = "migraphx_api_load_save.mxr";
     auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
     auto s1 = p1.get_output_shapes();
-
     migraphx::save(p1, filename.c_str());
     auto p2 = migraphx::load(filename.c_str());
     auto s2 = p2.get_output_shapes();
...
@@ -35,13 +35,14 @@
 #include <migraphx/gpu/compile_hip.hpp>
 #include <migraphx/gpu/compile_hip_code_object.hpp>
 #include <migraphx/gpu/compiler.hpp>
+#include <migraphx_kernels.hpp>
 
 // NOLINTNEXTLINE
 const std::string write_2s = R"__migraphx__(
 #include <hip/hip_runtime.h>
 
 extern "C" {
-__global__ void write(int8_t* data)
+__global__ void write(char* data)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     data[num] = 2;
@@ -58,7 +59,7 @@ const std::string add_2s_binary = R"__migraphx__(
 #include <hip/hip_runtime.h>
 
 extern "C" {
-__global__ void add_2(std::int8_t* x, std::int8_t* y)
+__global__ void add_2(char* x, char* y)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     y[num] = x[num] + 2;
@@ -137,7 +138,8 @@ int main() {}
 const std::string math_template = R"__migraphx__(
 #include <migraphx/kernels/pointwise.hpp>
 #include <migraphx/kernels/math.hpp>
+#include <migraphx/kernels/types.hpp>
+using namespace migraphx;
 extern "C" {
 __global__ void kernel(${type}* p)
 {
...