Commit fba2cd0c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents 4d358059 8d5a2210
......@@ -12,7 +12,7 @@ namespace op {
struct argmax
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -26,7 +26,7 @@ struct argmax
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
//#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
//#include <cmath>
//#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,7 +12,7 @@ namespace op {
struct argmin
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +26,7 @@ struct argmin
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
......
......@@ -285,10 +285,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -300,7 +300,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......@@ -312,10 +312,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -327,7 +327,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......
......@@ -650,44 +650,6 @@ struct cpu_logsoftmax
}
};
struct cpu_argmax
{
    // Wrapped reference operation; all real work is delegated to it.
    op::argmax op;

    // Expose the wrapped op's fields so serialization/introspection sees them.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::argmax"; }

    // Shape inference is forwarded to the generic op::argmax implementation.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        return op.compute_shape(inputs);
    }

    // The CPU target needs no context; forward straight to the reference compute.
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        return op.compute(output_shape, std::move(args));
    }
};
struct cpu_argmin
{
    // Wrapped reference operation; all real work is delegated to it.
    op::argmin op;

    // Expose the wrapped op's fields so serialization/introspection sees them.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::argmin"; }

    // Shape inference is forwarded to the generic op::argmin implementation.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        return op.compute_shape(inputs);
    }

    // The CPU target needs no context; forward straight to the reference compute.
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        return op.compute(output_shape, std::move(args));
    }
};
struct cpu_apply
{
program* prog;
......@@ -707,8 +669,6 @@ struct cpu_apply
void init()
{
apply_map["argmax"] = extend_op<cpu_argmax, op::argmax>();
apply_map["argmin"] = extend_op<cpu_argmin, op::argmin>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
......
......@@ -12,12 +12,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg.visit([&](auto input) {
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
arg_op<type, argmax_op<type>>(argmax_op<type>{}, stream, result, arg, axis);
});
arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device
......
......@@ -12,12 +12,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg.visit([&](auto input) {
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
arg_op<type, argmin_op<type>>(argmin_op<type>{}, stream, result, arg, axis);
});
arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device
......
......@@ -22,8 +22,20 @@ struct val_index
};
// Build a val_index holding `v` with a sentinel index of -1 (no position yet).
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
    return val_index<T>{v, -1};
}

// Build a val_index pairing `v` with its position `i`.
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
    return val_index<T>{v, i};
}
struct argmax_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
......@@ -36,12 +48,12 @@ struct argmax_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); }
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
template <class T>
struct argmin_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
......@@ -54,11 +66,11 @@ struct argmin_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); }
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class T, class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
......@@ -69,21 +81,21 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch.
const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
T init_val = op.init();
val_index<T> init = {init_val, -1};
auto init = make_val_index<type>(op.init());
auto op_output = block_reduce<max_block_size, Op, val_index<T>>(
idx, op, init, batch_item_num, [&](auto j) __device__ {
auto op_output =
block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
T val = input[arg_s.index(data_idx)];
return val_index<T>{val, static_cast<int64_t>(j)};
return make_val_index(input[arg_s.index(data_idx)], j);
});
if(idx.local == 0)
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
......@@ -612,7 +612,7 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 4, 1026, 6}};
migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment