Commit d0202590 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'test_runner_match_input_output' into migraphx_for_ort

parents 2e43e30b 414ea291
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct precompile_op
{
operation op = op::identity{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::precompile_op"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
inputs.pop_back();
return op.compute_shape(inputs, mods);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
MIGRAPHX_REGISTER_OP(precompile_op);
struct pointwise_compiler
{
std::string name() const { return "pointwise"; }
operation apply(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
return compile_pointwise(ctx, to_shapes(ins->inputs()), *pm);
}
};
using compiler_function = std::function<operation(context&, instruction_ref, operation)>;
template <class T>
compiler_function make_compiler_function(T x)
{
return {[=](auto&&... xs) { return x.apply(xs...); }};
}
template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
return {{xs.name(), make_compiler_function(xs)}...};
}
struct compiled_result
{
operation op;
instruction_ref ins;
};
void compile_ops::apply(module& m) const
{
auto compilers = make_compilers(pointwise_compiler{});
std::vector<std::function<compiled_result()>> compiles;
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
assert(contains(compilers, preop.name()));
auto c = compilers[preop.name()];
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
}
std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -2,9 +2,14 @@ ...@@ -2,9 +2,14 @@
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -17,6 +22,8 @@ static const char* const pointwise_kernel = R"__migraphx__( ...@@ -17,6 +22,8 @@ static const char* const pointwise_kernel = R"__migraphx__(
using namespace migraphx; using namespace migraphx;
${preamble}
extern "C" { extern "C" {
__global__ void kernel(${params}) __global__ void kernel(${params})
{ {
...@@ -29,7 +36,10 @@ int main() {} ...@@ -29,7 +36,10 @@ int main() {}
)__migraphx__"; )__migraphx__";
operation compile_pointwise(context&, const std::vector<shape>& inputs, const std::string& lambda) operation compile_pointwise(context&,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble)
{ {
hip_compile_options options; hip_compile_options options;
options.global = compute_global(inputs.front().elements()); options.global = compute_global(inputs.front().elements());
...@@ -37,13 +47,23 @@ operation compile_pointwise(context&, const std::vector<shape>& inputs, const st ...@@ -37,13 +47,23 @@ operation compile_pointwise(context&, const std::vector<shape>& inputs, const st
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.reduced_inputs = reduce_dims(inputs); options.reduced_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", lambda}}); {"lambda", lambda},
{"preamble", preamble}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m)
{
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
return compile_pointwise((ctx), inputs, "&" + name, g.str());
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -14,17 +14,29 @@ namespace gpu { ...@@ -14,17 +14,29 @@ namespace gpu {
static const char* const roialign_kernel = R"__migraphx__( static const char* const roialign_kernel = R"__migraphx__(
#include <migraphx/kernels/roialign.hpp> #include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp> #include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
using namespace migraphx; namespace migraphx {
extern "C" { extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y) __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{ {
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) { roialign(xs...); }); make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
_c<bool{IS_AVG_POOLING}>,
_c<int64_t{SAMPLING_RATIO}>,
MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));
roialign(xs..., settings);
});
} }
} }
} // namespace migraphx
int main() {} int main() {}
)__migraphx__"; )__migraphx__";
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct context;
struct compile_ops
{
context* ctx = nullptr;
std::string name() const { return "gpu::compile_ops"; }
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
...@@ -6,11 +6,17 @@ ...@@ -6,11 +6,17 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu { namespace gpu {
struct context; struct context;
operation operation compile_pointwise(context& ctx,
compile_pointwise(context& ctx, const std::vector<shape>& inputs, const std::string& lambda); const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble = "");
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -176,23 +176,23 @@ struct array ...@@ -176,23 +176,23 @@ struct array
} }
}; };
template <class T, T... xs> template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(xs)> struct integral_const_array : array<T, sizeof...(Xs)>
{ {
using base_array = array<T, sizeof...(xs)>; using base_array = array<T, sizeof...(Xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {} MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}
}; };
template <class T, T... xs, class F> template <class T, T... Xs, class F>
constexpr auto transform(integral_const_array<T, xs...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, F f)
{ {
return integral_const_array<T, f(xs)...>{}; return integral_const_array<T, f(Xs)...>{};
} }
template <class T, T... xs, class U, U... ys, class F> template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{ {
return integral_const_array<T, f(xs, ys)...>{}; return integral_const_array<T, f(Xs, Ys)...>{};
} }
template <index_int... Ns> template <index_int... Ns>
......
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP #ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP #define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <hip/hip_runtime.h> #include <migraphx/kernels/hip.hpp>
namespace migraphx { namespace migraphx {
inline __host__ __device__ void // Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN
#else
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
// noreturn cannot be used on this function because abort in hip is broken
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function) assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
{ {
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion); printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
......
...@@ -168,6 +168,7 @@ constexpr auto transform_args(F f, Fs... fs) ...@@ -168,6 +168,7 @@ constexpr auto transform_args(F f, Fs... fs)
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); }; return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
} }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); }) ([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
......
#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
namespace migraphx {
template <class F>
struct generic_constant
{
static constexpr auto value = F{}();
using value_type = decltype(value);
using type = generic_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
};
template <class F>
constexpr generic_constant<F> make_generic_constant(F)
{
return {};
}
// NOLINTNEXTLINE
#define MIGRAPHX_MAKE_CONSTANT(x) \
make_generic_constant([] { \
struct fun \
{ \
constexpr auto operator()() const { return x; } \
}; \
return fun{}; \
}())
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP #ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP #define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <hip/hip_runtime.h> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
namespace migraphx { namespace migraphx {
...@@ -17,7 +17,7 @@ struct index ...@@ -17,7 +17,7 @@ struct index
#ifdef MIGRAPHX_NGLOBAL #ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL; return MIGRAPHX_NGLOBAL;
#else #else
return blockDim.x * gridDim.x; return blockDim.x * gridDim.x; // NOLINT
#endif #endif
} }
...@@ -26,7 +26,7 @@ struct index ...@@ -26,7 +26,7 @@ struct index
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL; return MIGRAPHX_NLOCAL;
#else #else
return blockDim.x; return blockDim.x; // NOLINT
#endif #endif
} }
...@@ -53,7 +53,7 @@ struct index ...@@ -53,7 +53,7 @@ struct index
inline __device__ index make_index() inline __device__ index make_index()
{ {
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -5,28 +5,30 @@ ...@@ -5,28 +5,30 @@
namespace migraphx { namespace migraphx {
template <class T, T v> template <class T, T V>
struct integral_constant struct integral_constant
{ {
static constexpr T value = v; static constexpr T value = V;
using value_type = T; using value_type = T;
using type = integral_constant; using type = integral_constant;
constexpr operator value_type() const noexcept { return value; } constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; } constexpr value_type operator()() const noexcept { return value; }
}; };
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \ #define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T v, class U, U w> \ template <class T, T V, class U, U w> \
constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \ constexpr inline integral_constant<decltype(V op w), (V op w)> operator op( \
integral_constant<T, v>, integral_constant<U, w>) noexcept \ integral_constant<T, V>, integral_constant<U, w>) noexcept \
{ \ { \
return {}; \ return {}; \
} }
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \ #define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T v> \ template <class T, T V> \
constexpr inline integral_constant<decltype(op v), (op v)> operator op( \ constexpr inline integral_constant<decltype(op V), (op V)> operator op( \
integral_constant<T, v>) noexcept \ integral_constant<T, V>) noexcept \
{ \ { \
return {}; \ return {}; \
} }
...@@ -64,8 +66,8 @@ using false_type = bool_constant<false>; ...@@ -64,8 +66,8 @@ using false_type = bool_constant<false>;
template <index_int N> template <index_int N>
using index_constant = integral_constant<index_int, N>; using index_constant = integral_constant<index_int, N>;
template <auto v> template <auto V>
static constexpr auto _c = integral_constant<decltype(v), v>{}; static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP #endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) ...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template <class F, class... Ts> template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps) __device__ void pointwise(F f, Ts*... ps)
{ {
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize()); auto t = transform_args(make_tensors(), rotate_last());
t(ps...)([&](auto... xs) { t(ps...)([&](auto... xs) {
auto idx = make_index(); auto idx = make_index();
pointwise_tensor(idx, f, xs...); pointwise_tensor(idx, f, xs...);
......
...@@ -14,9 +14,7 @@ constexpr auto traverse_preload(Shapes... ss) ...@@ -14,9 +14,7 @@ constexpr auto traverse_preload(Shapes... ss)
auto each = [&](auto x) { auto each = [&](auto x) {
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = _c<s.element_space()>;
if constexpr(not s.broadcasted()) if constexpr(not s.broadcasted() or (s.elements() - size) < 64)
return f(x, offset, false_type{});
else if constexpr((s.elements() - size) < 64)
return f(x, offset, false_type{}); return f(x, offset, false_type{});
else else
{ {
......
#ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP #ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#define MIGRAPHX_GUARD_KERNELS_PRINT_HPP #define MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#include <hip/hip_runtime.h> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp> #include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp> #include <migraphx/kernels/basic_ops.hpp>
#include <args.hpp> #include <migraphx/kernels/array.hpp>
namespace migraphx { namespace migraphx {
...@@ -104,14 +104,24 @@ MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data, ...@@ -104,14 +104,24 @@ MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data,
return op.final(output_val, count); return op.final(output_val, count);
} }
template <class T, class U, class V, class W> template <class T1, class T2, class T3, class T4>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t) struct roalign_settings
{ {
const float roi_offset = ROIS_OFFSET; T1 roi_offset{};
const bool is_avg_pooling = IS_AVG_POOLING; T2 is_avg_pooling{};
const int64_t sampling_ratio = SAMPLING_RATIO; T3 sampling_ratio{};
const float spatial_scale = SPATIAL_SCALE; T4 spatial_scale{};
};
template <class... Ts>
constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{
auto index = make_index(); auto index = make_index();
const auto* x = x_t.data(); const auto* x = x_t.data();
const auto* rois = rois_t.data(); const auto* rois = rois_t.data();
...@@ -146,9 +156,10 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -146,9 +156,10 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
const auto* offset_rois = rois + (n * roi_column_num); const auto* offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * spatial_scale, array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * spatial_scale}; offset_rois[0] * s.spatial_scale};
array<float, 2> roi_ends = {offset_rois[3] * spatial_scale, offset_rois[2] * spatial_scale}; array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale,
offset_rois[2] * s.spatial_scale};
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
...@@ -161,11 +172,11 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -161,11 +172,11 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
bin_size[ii] = roi_size[ii] / out_dims[ii]; bin_size[ii] = roi_size[ii] / out_dims[ii];
bin_grid_size[ii] = bin_grid_size[ii] =
(sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]); (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
} }
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
out_ptr[i] = calc_pooling(offset_x, out_ptr[i] = calc_pooling(offset_x,
roi_starts, roi_starts,
...@@ -173,7 +184,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -173,7 +184,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
{ph, pw}, {ph, pw},
bin_grid_size, bin_grid_size,
in_dims, in_dims,
roi_offset, s.roi_offset,
avg_pool{}); avg_pool{});
} }
else else
...@@ -184,7 +195,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -184,7 +195,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
{ph, pw}, {ph, pw},
bin_grid_size, bin_grid_size,
in_dims, in_dims,
roi_offset, s.roi_offset,
max_pool{}); max_pool{});
} }
} }
......
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <hip/hip_runtime.h> #include <migraphx/kernels/hip.hpp>
namespace migraphx { namespace migraphx {
...@@ -12,6 +12,8 @@ using index_int = std::uint32_t; ...@@ -12,6 +12,8 @@ using index_int = std::uint32_t;
template <class T, index_int N> template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N))); using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -13,7 +13,7 @@ constexpr auto vec_size(vec<T, N>) ...@@ -13,7 +13,7 @@ constexpr auto vec_size(vec<T, N>)
} }
template <class T> template <class T>
constexpr auto vec_size(T, ...) constexpr auto vec_size(T, ...) // NOLINT
{ {
return index_constant<0>{}; return index_constant<0>{};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment