"vscode:/vscode.git/clone" did not exist on "1a665a63b09a83ab06317f8acfe7e7f75037c5ab"
Commit 51597ed7 authored by Khalique

fix tests and tf parser

parents 7bacd3ba bc80dee8
@@ -73,7 +73,7 @@ __host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i)
 inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
 {
-    std::size_t groups  = 1 + n / local;
+    std::size_t groups  = (n + local - 1) / local;
     std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
     return [=](auto f) {
...
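Note: the `groups` change is plain ceiling division. The old form `1 + n / local` launches one extra group whenever `n` is an exact multiple of `local`; `(n + local - 1) / local` gives the minimal group count. A standalone sketch of the arithmetic (names hypothetical):

```cpp
#include <cassert>
#include <cstddef>

// Smallest number of groups of size `local` that cover n items.
constexpr std::size_t ceil_div(std::size_t n, std::size_t local)
{
    return (n + local - 1) / local;
}

int main()
{
    assert(ceil_div(1024, 1024) == 1); // old formula: 1 + 1024 / 1024 == 2
    assert(ceil_div(1025, 1024) == 2); // both formulas agree here
    assert(ceil_div(1, 1024) == 1);
    return 0;
}
```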
@@ -128,7 +128,7 @@ __device__ T dpp_mov(T& x)
 template <class T, class Op>
 __device__ void dpp_reduce(T& in, Op op)
 {
-    T out;
+    T out{};
     out = dpp_mov<dpp_row_shr(1)>(in);
     in  = op(in, out);
     out = dpp_mov<dpp_row_shr(2)>(in);
...
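Note: `T out{};` value-initializes the accumulator instead of leaving it indeterminate. Since `out` is assigned before each use here, this is most likely about silencing uninitialized-variable warnings rather than changing behavior; the language rule it leans on:

```cpp
int a;   // default-initialized: holds an indeterminate value
int b{}; // value-initialized: guaranteed to be zero
```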
@@ -119,13 +119,13 @@ tensor_view<device_type<T>> device_cast(tensor_view<T> x)
 }
 template <class T>
-T to_hip_type(T x)
+__device__ __host__ T to_hip_type(T x)
 {
     return x;
 }
 // HIP doesn't support __fp16
-inline float to_hip_type(gpu_half x) { return x; }
+inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
 } // namespace device
 } // namespace gpu
...
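Note: in HIP (as in CUDA), a function can only be called from device code if it carries the `__device__` qualifier, so `to_hip_type` needs `__device__ __host__` to be usable from both the kernels below and host code. A minimal illustration:

```cpp
// Callable from host code only:
inline float host_only_scale(float x) { return 2.0f * x; }

// Callable from both a __global__ kernel and host code:
__device__ __host__ inline float scale(float x) { return 2.0f * x; }
```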
 #include <migraphx/shape.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/gpu/device/logsoftmax.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
@@ -11,53 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-argument logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto lens = result.get_shape().lens();
-    auto num_in_batch = lens[axis];
     auto batch_lens   = lens;
+    std::size_t batch_item_num = lens[axis];
     batch_lens[axis]  = 1;
-    shape batch_shape{result.get_shape().type(), batch_lens};
+    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        // each thread is for one item in the batch
-        gs_launch(stream, batch_shape.elements())([=](auto i) {
-            auto batch_idx = batch.multi(i);
-            auto data_idx  = batch_idx;
-            // get max
-            auto batch_max = input[batch_idx];
-            for(std::size_t j = 1; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
-            }
-            for(std::size_t j = 0; j < num_in_batch; ++j)
-            {
-                data_idx[axis]   = j;
-                output[data_idx] = input[data_idx] - batch_max;
-            }
-            auto batch_sum = ::exp(to_hip_type(output[batch_idx]));
-            for(std::size_t j = 1; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                batch_sum += ::exp(to_hip_type(output[data_idx]));
-            }
-            batch_sum = ::log(to_hip_type(batch_sum));
-            for(std::size_t j = 0; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                output[data_idx] -= batch_sum;
-            }
-        });
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
+            auto data_idx = batch.multi(i / block_size);
+            using type    = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init     = lowest();
+
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
+
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val       = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
+            auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max;
+
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis]   = j;
+                output[data_idx] = input[data_idx] - log_batch_sum;
+            });
         });
     });
-    return result;
 }
 } // namespace device
...
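Note: the rewritten kernel is the standard numerically stable log-softmax. With $m = \max_k x_k$:

$$\operatorname{logsoftmax}(x)_j \;=\; x_j - \log\sum_k e^{x_k} \;=\; x_j - \Bigl(m + \log\sum_k e^{x_k - m}\Bigr),$$

which is exactly `input[data_idx] - log_batch_sum` with `log_batch_sum = log(batch_sum) + batch_max`. Shifting by the max keeps every `exp` argument non-positive, so the sum cannot overflow.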
@@ -2,6 +2,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/gpu/device/softmax.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
@@ -12,51 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-argument softmax(hipStream_t stream, argument result, argument arg, int axis)
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto lens       = result.get_shape().lens();
     auto batch_lens = lens;
-    size_t n_dims   = lens[axis];
+    std::size_t batch_item_num = lens[axis];
     batch_lens[axis] = 1;
-    shape batch_shape{result.get_shape().type(), batch_lens};
+    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        // each thread is for one item in the batch
-        gs_launch(stream, batch_shape.elements())([=](auto i) {
-            auto batch_idx = batch.multi(i);
-            auto data_idx  = batch_idx;
-            // get max
-            auto batch_max = input[batch_idx];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
-            }
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                data_idx[axis]   = j;
-                output[data_idx] = exp(to_hip_type(input[data_idx] - batch_max));
-            }
-            auto batch_sum = output[batch_idx];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                batch_sum += output[data_idx];
-            }
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                output[data_idx] = output[data_idx] / batch_sum;
-            }
-        });
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
+            auto data_idx = batch.multi(i / block_size);
+            using type    = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init     = lowest();
+
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
+
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val       = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
+
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis]   = j;
+                auto val         = input[data_idx] - batch_max;
+                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
+            });
         });
     });
-    return result;
 }
 } // namespace device
...
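Note: both kernels now assign one thread block per batch row and reduce that row's `batch_item_num` elements cooperatively. As used here, `block_reduce<N>(idx, op, init, n, f)` folds `f(0), ..., f(n-1)` with `op` starting from `init`. A serial stand-in for that contract (the real version splits the `j` range across the block's threads and combines partials with DPP/LDS):

```cpp
#include <cstddef>

// Serial sketch of the device-side block_reduce contract used above:
// reduce f(0..n-1) with a binary op, starting from init.
template <class Op, class T, class F>
T serial_block_reduce(Op op, T init, std::size_t n, F f)
{
    T result = init;
    for(std::size_t j = 0; j < n; j++)
        result = op(result, f(j));
    return result;
}
```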
@@ -200,12 +200,33 @@ struct hip_add_relu
     }
 };

+void move_broadcasted_back(std::vector<instruction_ref>& args)
+{
+    // Ensure the last argument is the broadcasted one
+    auto it = std::find_if(
+        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != args.end())
+        std::swap(*it, *std::prev(args.end(), 2));
+}
+
+void move_standard_front(std::vector<instruction_ref>& args)
+{
+    // Ensure the first argument is the standard one
+    auto it = std::find_if(
+        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
+    if(it != args.end())
+        std::swap(*it, args.front());
+}
+
 struct find_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::arg(0)(
-            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+        return match::name("gpu::relu")(
+            match::arg(0)(match::any_of(match::name("gpu::add"),
+                                        match::name("hip::triadd"),
+                                        match::any_of[match::inputs()](match::standard_shape()))
+                              .bind("add")));
     }

     void apply(program& p, match::matcher_result r) const
@@ -213,6 +234,9 @@ struct find_add_relu
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;
         auto args    = add_ins->inputs();
+        move_standard_front(args);
+        move_broadcasted_back(args);
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
@@ -226,8 +250,9 @@ struct find_triadd
 {
     auto matcher() const
     {
-        return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
-                                                               match::any().bind("input")));
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::add").bind("add"),
+            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
     }

     void apply(program& p, match::matcher_result r) const
@@ -236,14 +261,15 @@ struct find_triadd
         auto input_ins = r.instructions["input"];
         auto ins       = r.result;
         auto args      = add_ins->inputs();
+        assert(add_ins != input_ins);
         auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
         if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
             return;
         args.insert(args.begin(), input_ins);
-        // Ensure the last arguments is the broadcasted one
-        auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
-        if(it != args.end())
-            std::swap(*it, *std::prev(args.end(), 2));
+        move_standard_front(args);
+        move_broadcasted_back(args);
         args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_triadd{}, args);
     }
@@ -402,8 +428,8 @@ void fuse_ops::apply(program& p) const
     // clang-format off
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
-        // find_conv_bias_relu{ctx},
-        // find_conv_bias{ctx},
+        find_conv_bias_relu{ctx},
+        find_conv_bias{ctx},
         find_add_relu{}
     );
     // clang-format on
...
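Note: the two helpers canonicalize the fused op's argument order: the standard-shaped input goes first and the broadcasted one lands in the second-to-last slot, since the last slot is always the output allocation (`args.back() = ins->inputs().back()`). A host-side sketch of the same find-and-swap reordering on plain strings:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> args = {"broadcast", "standard", "alloc"};
    // move_standard_front:
    auto s = std::find(args.begin(), args.end(), "standard");
    if(s != args.end())
        std::swap(*s, args.front());
    // move_broadcasted_back:
    auto b = std::find(args.begin(), args.end(), "broadcast");
    if(b != args.end())
        std::swap(*b, *std::prev(args.end(), 2));
    assert(args.front() == "standard");
    assert(*std::prev(args.end(), 2) == "broadcast");
    return 0;
}
```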
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmax
{
op::argmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmin
{
op::argmin op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T>
struct val_index
{
T val;
int64_t index;
};
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
return {v, -1};
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
return {v, i};
}
struct argmax_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
return x;
else if(x.val > y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
auto batch_lens = lens;
size_t batch_item_num = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{arg_shape.type(), batch_lens};
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block to reduce the items of one batch
const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
auto init = make_val_index<type>(op.init());
auto op_output =
block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return make_val_index(input[arg_s.index(data_idx)], j);
});
if(idx.local == 0)
{
output[batch_s.index(batch_idx)] = op_output.index;
}
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
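Note: `argmax_op`/`argmin_op` break ties toward the smaller index, matching the usual framework convention. A quick host-side check of that fold (mirroring the reduction in `arg_op`):

```cpp
#include <cassert>
#include <cstdint>

struct vi
{
    float val;
    int64_t index;
};

// Same tie-break rule as argmax_op: equal values resolve to the smaller index.
static vi pick_max(vi x, vi y)
{
    if(x.val > y.val)
        return x;
    if(x.val < y.val)
        return y;
    return (x.index < y.index) ? x : y;
}

int main()
{
    vi best{-3.4e38f, -1}; // analogous to make_val_index<type>(op.init())
    const float data[] = {1.0f, 3.0f, 3.0f};
    for(int64_t j = 0; j < 3; j++)
        best = pick_max(best, vi{data[j], j});
    assert(best.index == 1); // not 2, despite the tie at 3.0f
    return 0;
}
```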
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMIN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-argument logsoftmax(hipStream_t stream, argument result, argument arg, int axis);
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
 } // namespace device
 } // namespace gpu
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-argument softmax(hipStream_t stream, argument result, argument arg, int axis);
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
 } // namespace device
 } // namespace gpu
...
@@ -18,7 +18,8 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
 argument
 hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::logsoftmax(ctx.get_stream().get(), args[1], args[0], op.axis);
+    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
 }
 } // namespace gpu
...
@@ -11,6 +11,8 @@
 #include <migraphx/gpu/device/contiguous.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/gpu/argmax.hpp>
+#include <migraphx/gpu/argmin.hpp>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/convolution.hpp>
@@ -102,6 +104,8 @@ struct miopen_apply
         add_extend_op<hip_concat, op::concat>("concat");
         add_extend_op<hip_softmax, op::softmax>("softmax");
         add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
+        add_extend_op<hip_argmax, op::argmax>("argmax");
+        add_extend_op<hip_argmin, op::argmin>("argmin");
         add_extend_op<hip_gather, op::gather>("gather");
         add_extend_op<hip_pad, op::pad>("pad");
         add_extend_op<hip_convert, op::convert>("convert");
...
@@ -39,7 +39,8 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
 argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::softmax(ctx.get_stream().get(), args[1], args[0], op.axis);
+    device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
 }
 } // namespace gpu
...
@@ -36,6 +36,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     // clang-format off
     return
     {
+        dead_code_elimination{},
+        simplify_reshapes{},
         dead_code_elimination{},
         eliminate_identity{},
         eliminate_pad{},
@@ -48,11 +50,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         //dead_code_elimination{},
         simplify_algebra{},
         dead_code_elimination{},
-        propagate_constant{},
-        dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         dead_code_elimination{},
+        propagate_constant{},
+        dead_code_elimination{},
         lowering{ctx},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
...
@@ -37,8 +37,49 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;

-    std::vector<size_t>
-    parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const
+    bool should_transpose(instruction_ref ins) const
+    {
+        return is_nhwc and ins->get_shape().lens().size() == 4;
+    }
+
+    instruction_ref to_nhwc(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_nchw(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_kcxy(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref make_contiguous(instruction_ref ins)
+    {
+        if(ins->get_shape().standard())
+            return ins;
+        else
+            return prog.add_instruction(op::contiguous{}, ins);
+    }
+
+    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
+    {
+        std::vector<instruction_ref> result(args.size());
+        std::transform(
+            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
+        return result;
+    }
+
+    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
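Note: the three transpose helpers are fixed permutations: `{0, 2, 3, 1}` maps NCHW to NHWC, `{0, 3, 1, 2}` maps back (they are exact inverses), and `{3, 2, 0, 1}` turns TF's HWIO weight layout into the KCXY (OIHW) layout the convolution expects. With `transpose`, `dims_out[i] = dims_in[perm[i]]`; a small standalone check:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Apply a transpose permutation to a dims vector: out[i] = dims[perm[i]].
template <std::size_t N>
std::array<std::size_t, N> permute(const std::array<std::size_t, N>& dims,
                                   const std::array<std::size_t, N>& perm)
{
    std::array<std::size_t, N> out{};
    for(std::size_t i = 0; i < N; i++)
        out[i] = dims[perm[i]];
    return out;
}

int main()
{
    const std::array<std::size_t, 4> nchw{1, 3, 224, 224};
    const auto nhwc = permute(nchw, {0, 2, 3, 1}); // {1, 224, 224, 3}
    assert((permute(nhwc, {0, 3, 1, 2}) == nchw)); // inverse round-trip
    return 0;
}
```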
@@ -120,60 +161,68 @@ struct tf_parser
         add_mem_op("AvgPool", &tf_parser::parse_pooling);
         add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
-        add_mem_op("ConcatV2", &tf_parser::parse_concat);
+        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
         add_mem_op("ExpandDims", &tf_parser::parse_expanddims);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
-        add_mem_op("MatMul", &tf_parser::parse_matmul);
+        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
         add_mem_op("Mean", &tf_parser::parse_mean);
-        add_mem_op("Pack", &tf_parser::parse_pack);
+        add_mem_op("Pack", &tf_parser::parse_pack, false);
         add_mem_op("Pad", &tf_parser::parse_pad);
-        add_mem_op("Reshape", &tf_parser::parse_reshape);
+        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
         add_mem_op("Softmax", &tf_parser::parse_softmax);
-        add_mem_op("Squeeze", &tf_parser::parse_squeeze);
+        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
         add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
     }
     template <class F>
-    void add_op(std::string name, F f)
+    void add_op(std::string name, F f, bool transpose = true)
     {
-        ops.emplace(name, f);
-    }
-
-    // Multi output op
-    template <class F>
-    void add_multi_op(std::string name, F f)
-    {
-        ops.emplace(name, f);
+        if(transpose)
+        {
+            ops.emplace(name,
+                        op_func{[=](const attribute_map& attributes,
+                                    const std::vector<instruction_ref>& args) -> instruction_ref {
+                            return to_nhwc(f(attributes, to_nchw(args)));
+                        }});
+        }
+        else
+        {
+            ops.emplace(name, f);
+        }
     }

     template <class F>
-    void add_mem_op(std::string name, F f)
+    void add_mem_op(std::string name, F f, bool transpose = true)
     {
-        add_op(name, [=](auto&&... xs) {
-            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
-        });
+        add_op(name,
+               [=](auto&&... xs) {
+                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
+               },
+               transpose);
     }
     template <class T>
     void add_binary_op(std::string name, T x)
     {
-        add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
-            if(args.size() != 2)
-                MIGRAPHX_THROW("binary operators should have 2 operands");
-            auto l0 = args[1];
-            if(contains(attributes, "data_format"))
-            {
-                if(is_nhwc)
-                {
-                    l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
-                }
-            }
-            return add_broadcastable_binary_op(args[0], l0, x);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   if(args.size() != 2)
+                       MIGRAPHX_THROW("binary operators should have 2 operands");
+                   // TODO
+                   // if(contains(attributes, "data_format"))
+                   // {
+                   //     if(is_nhwc)
+                   //     {
+                   //         l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
+                   //     }
+                   // }
+                   return add_broadcastable_binary_op(args[0], args[1], x);
+               },
+               false);
     }

     template <class T>
@@ -212,20 +261,22 @@ struct tf_parser
             auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
             auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
-            return prog.add_instruction(x, l0, l1);
+            return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
         }
         else
         {
-            return prog.add_instruction(x, {arg0, arg1});
+            return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
         }
     }

     template <class T>
-    void add_generic_op(std::string name, T x)
+    void add_generic_op(std::string name, T x, bool transpose = true)
     {
-        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
-            return prog.add_instruction(x, args);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   return prog.add_instruction(x, args);
+               },
+               transpose);
     }

     instruction_ref
@@ -255,8 +306,7 @@ struct tf_parser
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis =
-            parse_axis(args[axis_idx]->eval().at<int64_t>(), args[0]->get_shape().lens().size());
+        size_t axis = args[axis_idx]->eval().at<int64_t>();
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -268,15 +318,7 @@ struct tf_parser
                    const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        auto l0   = prog.add_literal(v);
-        size_t num_axes = l0->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            std::vector<int64_t> transpose_axes = get_axes(num_axes);
-            reorder_data(transpose_axes);
-            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
-        }
-        return l0;
+        return prog.add_literal(v);
     }

     instruction_ref
@@ -307,21 +349,8 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
+        auto weights = to_kcxy(args[1]);
         auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
@@ -371,8 +400,7 @@ struct tf_parser
                 op.padding[1] = padding[1];
             }
         }
-        return prog.add_instruction(op, {l0, weights});
+        return prog.add_instruction(op, {l0, to_kcxy(args[1])});
     }
     instruction_ref parse_depthwiseconv(const std::string&,
@@ -395,6 +423,8 @@ struct tf_parser
             op.stride[0] = stride[2];
             op.stride[1] = stride[3];
         }
+        auto weights = to_kcxy(args[1]);
         if(contains(attributes, "dilations"))
         {
             std::vector<size_t> dilation;
@@ -408,20 +438,6 @@ struct tf_parser
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
         auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
@@ -469,8 +485,8 @@ struct tf_parser
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
         // Make sure weights are contiguous before doing reshape
-        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
+        auto new_weights =
+            prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
         return prog.add_instruction(op, {l0, new_weights});
     }
@@ -558,15 +574,14 @@ struct tf_parser
             MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                            " must be smaller than input size " + to_string(input_size));
         }
-        // check if input arg needs axis to be converted to NCHW
-        axis = parse_axis(axis, input_size);
         std::transform(
             args.begin(),
             args.end(),
             std::back_inserter(unsqueezed_args),
             [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
-        return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
+        return to_nhwc(
+            prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
     }

     instruction_ref
@@ -669,7 +684,7 @@ struct tf_parser
             MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
         auto s = args[1]->eval();
         s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }

     void parse_from(std::istream& is)
@@ -701,7 +716,7 @@ struct tf_parser
     {
         op::squeeze op;
         auto input_dims = args[0]->get_shape().lens();
-        auto axes       = parse_axes(attributes, "squeeze_dims", input_dims.size());
+        auto axes       = attributes.at("squeeze_dims").list().i();
         copy(axes, std::back_inserter(op.axes));
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
@@ -714,7 +729,7 @@ struct tf_parser
                 }
             }
         }
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }

     instruction_ref parse_stridedslice(const std::string&,
@@ -725,11 +740,6 @@ struct tf_parser
         auto starts     = args[1]->eval().get<int32_t>().to_vector();
         auto ends       = args[2]->eval().get<int32_t>().to_vector();
         size_t num_axes = args[0]->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            reorder_data(starts);
-            reorder_data(ends);
-        }
         op.starts = std::vector<int64_t>(starts.begin(), starts.end());
         op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
@@ -748,10 +758,9 @@ struct tf_parser
             if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                 squeeze_axes.push_back(i);
         }
-        squeeze_axes = parse_axes(squeeze_axes, num_axes);
-        auto l0 = prog.add_instruction(op, args[0]);
-        return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
+        auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
+        return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
     }

     void parse_graph(const tensorflow::GraphDef& graph)
@@ -768,7 +777,7 @@ struct tf_parser
             reorder_data(dims);
         }
         shape s            = shape{shape_type, dims};
-        instructions[name] = prog.add_parameter(name, s);
+        instructions[name] = to_nhwc(prog.add_parameter(name, s));
     }
     for(auto&& p : nodes)
     {
@@ -1118,6 +1127,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
 #else
     parser.parse_from(input);
 #endif
+    parser.to_nchw(std::prev(parser.prog.end()));
     return std::move(parser.prog);
 }
...
@@ -941,9 +941,6 @@ TEST_CASE(softmax_simple_test)
     auto result = p.eval({});
     std::vector<float> results_vector(2);
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
-    for(auto v : results_vector)
-        std::cout << v << "\t";
-    std::cout << std::endl;
     EXPECT(migraphx::verify_range(results_vector, s));
 }
@@ -1138,6 +1135,114 @@ TEST_CASE(logsoftmax_test_axis_3)
     EXPECT(migraphx::verify_range(results_vector, s));
 }
TEST_CASE(argmax_test_0)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_1)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_2)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_0)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_1)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_2)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(conv2d_test)
{
    migraphx::program p;
...
@@ -592,13 +592,13 @@ struct test_softmax2 : verify_program<test_softmax2>
     }
 };

-template <int Axis>
-struct test_softmax : verify_program<test_softmax<Axis>>
+template <int Axis, migraphx::shape::type_t T>
+struct test_softmax : verify_program<test_softmax<Axis, T>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
+        migraphx::shape s{T, {512, 4, 1067, 6}};
         auto param = p.add_parameter("0", s);
         p.add_instruction(migraphx::op::softmax{Axis}, param);
@@ -606,10 +606,38 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
     }
 };

-template struct test_softmax<0>;
-template struct test_softmax<1>;
-template struct test_softmax<2>;
-template struct test_softmax<3>;
+template struct test_softmax<0, migraphx::shape::float_type>;
+template struct test_softmax<2, migraphx::shape::float_type>;
+template struct test_softmax<1, migraphx::shape::double_type>;
+template struct test_softmax<3, migraphx::shape::double_type>;
+template struct test_softmax<0, migraphx::shape::half_type>;
+template struct test_softmax<1, migraphx::shape::half_type>;
+template struct test_softmax<2, migraphx::shape::half_type>;
+template struct test_softmax<3, migraphx::shape::half_type>;
+
+template <class T, int Axis>
+struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
+        auto param = p.add_parameter("data", s);
+        p.add_instruction(T{Axis}, param);
+        return p;
+    }
+};
+
+template struct test_arg_ops<migraphx::op::argmax, 0>;
+template struct test_arg_ops<migraphx::op::argmax, 1>;
+template struct test_arg_ops<migraphx::op::argmax, 2>;
+template struct test_arg_ops<migraphx::op::argmax, 3>;
+template struct test_arg_ops<migraphx::op::argmin, 0>;
+template struct test_arg_ops<migraphx::op::argmin, 1>;
+template struct test_arg_ops<migraphx::op::argmin, 2>;
+template struct test_arg_ops<migraphx::op::argmin, 3>;

 struct test_conv : verify_program<test_conv>
 {
@@ -3344,32 +3372,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
     }
 };

-template <int Axis>
-struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
-        auto param = p.add_parameter("0", s);
-        p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
-        return p;
-    }
-};
-
-template struct test_logsoftmax<0>;
-template struct test_logsoftmax<1>;
-template struct test_logsoftmax<2>;
-template struct test_logsoftmax<3>;
-
-template <int Axis>
-struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
+template <int Axis, migraphx::shape::type_t T>
+struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3}};
+        migraphx::shape s{T, {10, 4, 2080, 6}};
         auto param = p.add_parameter("0", s);
         p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
@@ -3377,7 +3386,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
     }
 };

-template struct test_logsoftmax_1<0>;
+template struct test_logsoftmax<0, migraphx::shape::float_type>;
+template struct test_logsoftmax<1, migraphx::shape::float_type>;
+template struct test_logsoftmax<2, migraphx::shape::float_type>;
+template struct test_logsoftmax<3, migraphx::shape::float_type>;
+template struct test_logsoftmax<1, migraphx::shape::double_type>;
+template struct test_logsoftmax<3, migraphx::shape::double_type>;
+template struct test_logsoftmax<1, migraphx::shape::half_type>;
+template struct test_logsoftmax<0, migraphx::shape::half_type>;
+template struct test_logsoftmax<2, migraphx::shape::half_type>;
+template struct test_logsoftmax<3, migraphx::shape::half_type>;

 struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
 {
...