"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4d0c7238e549caac773f7a9b3dcd25912696bb7b"
Commit 19ea0bf9 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-vector-reduce' into jit-vector-softmax

parents c4cd8b0a 85897b5a
...@@ -62,7 +62,7 @@ ...@@ -62,7 +62,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!wget -nc https://github.com/onnx/models/blob/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx" "!wget -nc https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
] ]
}, },
{ {
......
...@@ -23,7 +23,7 @@ unzip uncased_L-12_H-768_A-12.zip ...@@ -23,7 +23,7 @@ unzip uncased_L-12_H-768_A-12.zip
``` ```
5) Get BERT ONNX model (bertsquad-10.onnx): 5) Get BERT ONNX model (bertsquad-10.onnx):
``` ```
wget https://github.com/onnx/models/blob/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx wget https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
``` ```
6) Run the inference, it will compile and run the model on three questions and small data provided in `inputs.json`: 6) Run the inference, it will compile and run the model on three questions and small data provided in `inputs.json`:
``` ```
......
tensorflow==2.5.3 tensorflow==2.6.4
onnxruntime onnxruntime
tokenizers tokenizers
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -9,6 +10,9 @@ namespace onnx { ...@@ -9,6 +10,9 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean> struct parse_mean : op_parser<parse_mean>
{ {
const std::set<shape::type_t> float_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; } std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors /// Calculates the element-wise mean of n>=1 input tensors
...@@ -24,7 +28,8 @@ struct parse_mean : op_parser<parse_mean> ...@@ -24,7 +28,8 @@ struct parse_mean : op_parser<parse_mean>
auto divisor = info.add_literal( auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}}); migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
// TODO: Only divide when using floating-point if(contains(float_types, args[0]->get_shape().type()))
{
return std::accumulate(args.begin() + 1, return std::accumulate(args.begin() + 1,
args.end(), args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor), info.add_broadcastable_binary_op("div", args[0], divisor),
...@@ -36,6 +41,17 @@ struct parse_mean : op_parser<parse_mean> ...@@ -36,6 +41,17 @@ struct parse_mean : op_parser<parse_mean>
return info.add_broadcastable_binary_op("add", mean, div); return info.add_broadcastable_binary_op("add", mean, div);
}); });
} }
else
{
// Compute sum before division for integral types
auto sum = std::accumulate(
args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
return info.add_broadcastable_binary_op("add", accum, data_i);
});
return info.add_broadcastable_binary_op("div", sum, divisor);
}
}
}; };
} // namespace onnx } // namespace onnx
......
...@@ -16,11 +16,9 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -16,11 +16,9 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto bstride = s.strides()[n + 1]; auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1]; auto blen = s.lens()[n + 1];
if(astride == bstride * blen) if(astride == bstride * blen or alen == 1)
{
new_lens.push_back(alen * blen); new_lens.push_back(alen * blen);
} }
}
if(new_lens.size() != shapes.size()) if(new_lens.size() != shapes.size())
return false; return false;
std::size_t i = 0; std::size_t i = 0;
...@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true; return true;
} }
// Drop a trailing dimension of extent 1 that every shape shares;
// no-op unless all shapes have rank >= 2 and a trailing len of 1.
void reduce_dim1(std::vector<shape>& shapes)
{
    const bool all_trailing_one =
        std::all_of(shapes.begin(), shapes.end(), [](const auto& s) {
            return s.lens().size() >= 2 and s.lens().back() == 1;
        });
    if(not all_trailing_one)
        return;
    for(auto& s : shapes)
    {
        auto new_lens    = s.lens();
        auto new_strides = s.strides();
        new_lens.pop_back();
        new_strides.pop_back();
        s = shape{s.type(), new_lens, new_strides};
    }
}
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n) std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{ {
while(reduce_dim(shapes, n) and n < shapes.size()) {} while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1; return n + 1;
} }
void reduce_dim_all(std::vector<shape>& shapes) void reduce_dim_all(std::vector<shape>& shapes)
...@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes) ...@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std::size_t n = 0; std::size_t n = 0;
while(n < shapes.front().lens().size() - 1) while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n); n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
} }
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes) std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
......
...@@ -908,11 +908,6 @@ struct find_gemm_add ...@@ -908,11 +908,6 @@ struct find_gemm_add
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
return not i->get_shape().standard();
}))
return;
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
...@@ -931,6 +926,53 @@ struct find_gemm_add ...@@ -931,6 +926,53 @@ struct find_gemm_add
} }
}; };
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
// Fuses a pointwise add following a gemm into the gemm's beta*C term.
struct find_gemm_pointwise
{
    auto matcher() const
    {
        return pointwise_name("add")(
            match::nargs(3),
            match::all_of[match::inputs()](match::standard_shape()),
            match::either_arg(0, 1)(match::used_once().bind("c"),
                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins  = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto c_ins    = r.instructions["c"];

        auto fused = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
        // A non-zero beta means the C operand was already fused in
        if(not float_equal(fused.beta, 0))
            return;
        fused.beta = 1;

        // Rebuild the inputs as {A, B, C, output-allocation}
        auto new_inputs   = gemm_ins->inputs();
        auto alloc        = new_inputs.back();
        new_inputs.back() = c_ins;
        new_inputs.push_back(alloc);
        m.replace_instruction(add_ins, fused, new_inputs);
    }
};
struct find_commutative_broadcast struct find_commutative_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -967,7 +1009,11 @@ void fuse_ops::apply(module& m) const ...@@ -967,7 +1009,11 @@ void fuse_ops::apply(module& m) const
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{}); find_add_clip{});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{}); match::find_matches(m,
find_triadd_layernorm{},
find_gemm_add{},
find_gemm_pointwise{},
find_commutative_broadcast{});
} }
} // namespace gpu } // namespace gpu
......
#include <rocblas.h> #include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type) ...@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
} }
void blas_shape(const shape& s)
{
if(s.lens().size() < 2)
return;
if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(s.lens().size() < 3)
return;
shape batch_shape{s.type(),
{s.lens().begin(), s.lens().end() - 2},
{s.strides().begin(), s.strides().end() - 2}};
auto batch_shapes = reduce_dims({batch_shape});
if(batch_shapes.front().lens().size() != 1)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}
template <class R, class... Ts, class... Us> template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs) R rocblas_invoke(R (*f)(Ts...), Us... xs)
{ {
...@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs) ...@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
return f(xs..., nullptr, nullptr); return f(xs..., nullptr, nullptr);
} }
// A shape counts as transposed for rocBLAS only when its layout is
// transposed AND its innermost stride is not already contiguous.
static bool is_transposed(const shape& s)
{
    return s.transposed() and s.strides().back() != 1;
}
// Stride separating consecutive matrices in a batched operand.
// NOTE(review): assumes the shape has at least 3 dimensions — only
// reached from the strided-batched gemm path; confirm at call sites.
static rocblas_int get_batch_stride(const argument& a)
{
    const auto& strides = a.get_shape().strides();
    return strides[strides.size() - 3];
}
template <class T> template <class T>
void gemm_impl(context& ctx, void gemm_impl(context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -45,8 +74,8 @@ void gemm_impl(context& ctx, ...@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32) bool compute_fp32)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = is_transposed(args[0].get_shape());
bool transb = args[1].get_shape().transposed(); bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
...@@ -142,6 +171,9 @@ void gemm_impl(context& ctx, ...@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
} }
else else
{ {
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -153,20 +185,20 @@ void gemm_impl(context& ctx, ...@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
k * n, b_stride,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
m * k, a_stride,
beta_v, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
m * n, c_stride,
num_matrices, num_matrices,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
......
...@@ -18,6 +18,8 @@ namespace gpu { ...@@ -18,6 +18,8 @@ namespace gpu {
struct context; struct context;
void blas_shape(const shape& s);
template <class Op> template <class Op>
struct rocblas_gemm struct rocblas_gemm
{ {
...@@ -50,13 +52,14 @@ struct rocblas_gemm ...@@ -50,13 +52,14 @@ struct rocblas_gemm
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted(); check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides()); blas_shape(inputs[0]);
batch_not_transposed(inputs[1].strides()); blas_shape(inputs[1]);
// if gemm and add are fused // if gemm and add are fused
if(not float_equal(beta, 0)) if(in_shapes.size() > 2)
{ {
auto cmat_shape = in_shapes.back(); auto cmat_shape = in_shapes.back();
in_shapes.pop_back(); in_shapes.pop_back();
blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes); auto op_out_shape = op.compute_shape(in_shapes);
if(cmat_shape.lens() != op_out_shape.lens()) if(cmat_shape.lens() != op_out_shape.lens())
{ {
...@@ -71,6 +74,7 @@ struct rocblas_gemm ...@@ -71,6 +74,7 @@ struct rocblas_gemm
to_string(cmat_shape.type()) + to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type())); ", it must be: " + to_string(op_out_shape.type()));
} }
return op_out_shape;
} }
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
...@@ -96,28 +100,6 @@ struct rocblas_gemm ...@@ -96,28 +100,6 @@ struct rocblas_gemm
return args.back(); return args.back();
} }
// Throws if the batch dimensions of a GEMM operand are laid out as if
// transposed with the matrix dimensions. `strides` is the operand's
// full stride vector; the last two entries are the matrix strides.
void batch_not_transposed(const std::vector<std::size_t>& strides) const
{
// Nothing to validate without batch dimensions
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
// Larger of the two matrix strides bounds the matrix footprint
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
// Strides of the batch dimensions only
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
// Every batch stride smaller than the matrix => batch and matrix swapped
if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
{
MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
"} are transposed!");
}
// Batch strides must be non-increasing and each at least matrix_size;
// any adjacent pair violating that means the batch dims are transposed
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
"} is transposed!");
}
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -19,7 +19,7 @@ namespace gpu { ...@@ -19,7 +19,7 @@ namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__( static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp> #include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
......
...@@ -75,6 +75,7 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -75,6 +75,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs); auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v,
compute_global_for(ctx, compute_global_for(ctx,
......
...@@ -19,7 +19,6 @@ namespace gpu { ...@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const roialign_kernel = R"__migraphx__( static const char* const roialign_kernel = R"__migraphx__(
#include <migraphx/kernels/roialign.hpp> #include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
......
...@@ -19,7 +19,6 @@ namespace gpu { ...@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__( static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp> #include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp> #include <args.hpp>
......
...@@ -147,8 +147,8 @@ struct array ...@@ -147,8 +147,8 @@ struct array
constexpr array carry(array result) const constexpr array carry(array result) const
{ {
uint32_t overflow = 0; index_int overflow = 0;
for(std::ptrdiff_t i = result.size() - 1; i > 0; i--) for(diff_int i = result.size() - 1; i > 0; i--)
{ {
auto z = result[i] + overflow; auto z = result[i] + overflow;
// Reset overflow // Reset overflow
......
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
// Binary addition functor usable as a reduction operator.
struct sum
{
    template <class A, class B>
    constexpr auto operator()(A a, B b) const
    {
        return a + b;
    }
};
// Binary multiplication functor usable as a reduction operator.
struct product
{
    template <class A, class B>
    constexpr auto operator()(A a, B b) const
    {
        return a * b;
    }
};
// Identity functor: returns its argument unchanged.
struct id
{
    template <class V>
    constexpr auto operator()(V v) const
    {
        return v;
    }
};
// Unary functor dividing its argument by a fixed element count,
// turning an accumulated sum into an average.
struct mean
{
    // Number of items that contributed to the sum (defaults to 1).
    size_t item_num = 1;
    template <class V>
    constexpr auto operator()(V v) const
    {
        return v / static_cast<V>(item_num);
    }
};
// Binary functor returning the larger of two values; on a tie the
// second argument is returned (ternary falls through to b).
struct max_f
{
    template <class A, class B>
    constexpr auto operator()(A a, B b) const
    {
        return (a > b) ? a : b;
    }
};
inline constexpr auto max = max_f{};
// Binary functor returning the smaller of two values; on a tie the
// second argument is returned (ternary falls through to b).
struct min_f
{
    template <class A, class B>
    constexpr auto operator()(A a, B b) const
    {
        return (a < b) ? a : b;
    }
};
inline constexpr auto min = min_f{};
// Tag type that implicitly converts to the smallest finite value of
// whatever arithmetic type it is assigned to (reduction init value).
struct lowest
{
    template <class Target>
    constexpr operator Target() const
    {
        return std::numeric_limits<Target>::lowest();
    }
};
// Tag type that implicitly converts to the largest finite value of
// whatever arithmetic type it is assigned to (reduction init value).
struct highest
{
    template <class Target>
    constexpr operator Target() const
    {
        return std::numeric_limits<Target>::max();
    }
};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
...@@ -137,7 +137,7 @@ constexpr auto by(F f) ...@@ -137,7 +137,7 @@ constexpr auto by(F f)
template <class F, class... Ts> template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs) constexpr void each_args(F f, Ts&&... xs)
{ {
swallow{(f(std::forward<Ts>(xs)), 0)...}; swallow{(f(static_cast<Ts&&>(xs)), 0)...};
} }
template <class F> template <class F>
......
...@@ -13,7 +13,7 @@ struct basic_iota_iterator ...@@ -13,7 +13,7 @@ struct basic_iota_iterator
F f; F f;
using difference_type = diff_int; using difference_type = diff_int;
using reference = decltype(f(std::declval<Iterator>())); using reference = decltype(f(declval<Iterator>()));
using value_type = remove_reference_t<reference>; using value_type = remove_reference_t<reference>;
using pointer = add_pointer_t<value_type>; using pointer = add_pointer_t<value_type>;
......
...@@ -3,14 +3,15 @@ ...@@ -3,14 +3,15 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp> #include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/array.hpp> #include <migraphx/kernels/array.hpp>
namespace migraphx { namespace migraphx {
struct max_pool struct max_pool
{ {
MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest{}; }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
...@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( ...@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
return 0; return 0;
} }
xy[ii] = max(xy[ii], 0.0f); xy[ii] = migraphx::max(xy[ii], 0.0f);
low[ii] = xy[ii]; low[ii] = xy[ii];
high[ii] = low[ii] + 1; high[ii] = low[ii] + 1;
if(low[ii] >= dims[ii] - 1) if(low[ii] >= dims[ii] - 1)
...@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
for(index_int ii = 0; ii < roi_size.size(); ++ii) for(index_int ii = 0; ii < roi_size.size(); ++ii)
{ {
roi_size[ii] = roi_ends[ii] - roi_starts[ii]; roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f); roi_size[ii] = migraphx::max(roi_size[ii], 1.0f);
bin_size[ii] = roi_size[ii] / out_dims[ii]; bin_size[ii] = roi_size[ii] / out_dims[ii];
bin_grid_size[ii] = bin_grid_size[ii] = (s.sampling_ratio > 0)
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]); ? s.sampling_ratio
: migraphx::ceil(roi_size[ii] / out_dims[ii]);
} }
const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
......
...@@ -11,7 +11,7 @@ template <class T> ...@@ -11,7 +11,7 @@ template <class T>
struct tensor_view_iterator_read struct tensor_view_iterator_read
{ {
T* view; T* view;
constexpr auto& operator()(std::size_t n) const constexpr auto& operator()(index_int n) const
{ {
MIGRAPHX_ASSERT(view != nullptr); MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n]; return (*view)[n];
......
...@@ -35,6 +35,21 @@ struct enable_if<true, T> ...@@ -35,6 +35,21 @@ struct enable_if<true, T>
template <bool B, class T = void> template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type; using enable_if_t = typename enable_if<B, T>::type;
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \ #define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \ template <class T> \
......
...@@ -80,7 +80,7 @@ __device__ __host__ auto as_vec(T* x) ...@@ -80,7 +80,7 @@ __device__ __host__ auto as_vec(T* x)
} }
template <class T, index_int N> template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>; using safe_vec = vec<conditional_t<is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts> template <class... Ts>
constexpr auto vec_transform(Ts... xs) constexpr auto vec_transform(Ts... xs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment