Commit 20b1d690 authored by Paul

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
    return inputs.at(0);
}

argument
hip_int8_gemm_pack_a::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    // pack args[0] into the preallocated destination buffer args[1] and return it
    device::int8_gemm_pack_a(ctx.get_stream().get(), args[1], args[0]);
    return args[1];
}

shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
    return inputs.at(0);
}

argument
hip_int8_gemm_pack_b::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    // pack args[0] into the preallocated destination buffer args[1] and return it
    device::int8_gemm_pack_b(ctx.get_stream().get(), args[1], args[0]);
    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -15,11 +15,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
     return op.compute_shape({inputs.at(0)});
 }
 
-argument hip_logsoftmax::compute(context& ctx,
-                                 const shape& output_shape,
-                                 const std::vector<argument>& args) const
+argument
+hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::logsoftmax(ctx.get_stream().get(), output_shape, args, op.axis);
+    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
 }
 
 } // namespace gpu
@@ -11,9 +11,12 @@
 #include <migraphx/gpu/device/contiguous.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/gpu/argmax.hpp>
+#include <migraphx/gpu/argmin.hpp>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/quant_convolution.hpp>
 #include <migraphx/gpu/contiguous.hpp>
 #include <migraphx/gpu/relu.hpp>
 #include <migraphx/gpu/sigmoid.hpp>
@@ -24,9 +27,12 @@
 #include <migraphx/gpu/logsoftmax.hpp>
 #include <migraphx/gpu/add.hpp>
 #include <migraphx/gpu/sub.hpp>
+#include <migraphx/gpu/div.hpp>
 #include <migraphx/gpu/exp.hpp>
+#include <migraphx/gpu/erf.hpp>
 #include <migraphx/gpu/log.hpp>
 #include <migraphx/gpu/sin.hpp>
+#include <migraphx/gpu/sign.hpp>
 #include <migraphx/gpu/cos.hpp>
 #include <migraphx/gpu/tan.hpp>
 #include <migraphx/gpu/sinh.hpp>
@@ -47,6 +53,14 @@
 #include <migraphx/gpu/lrn.hpp>
 #include <migraphx/gpu/convert.hpp>
 #include <migraphx/gpu/clip.hpp>
+#include <migraphx/gpu/reduce_sum.hpp>
+#include <migraphx/gpu/round.hpp>
+#include <migraphx/gpu/rsqrt.hpp>
+#include <migraphx/gpu/sqrt.hpp>
+#include <migraphx/gpu/reduce_mean.hpp>
+#include <migraphx/gpu/pow.hpp>
+#include <migraphx/gpu/sqdiff.hpp>
+#include <migraphx/gpu/int8_conv_pack.hpp>
 #include <utility>
 #include <functional>
 #include <algorithm>
@@ -72,10 +86,8 @@ struct miopen_apply
     void init()
     {
         this->last = instruction::get_output_alias(std::prev(prog->end()));
-        add_miopen_simple_op<miopen_relu>("relu", make_relu);
-        add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
         add_miopen_simple_op<miopen_abs>("abs", make_abs);
-        add_miopen_simple_op<miopen_tanh>("tanh", make_tanh);
 
         add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
         add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);
@@ -83,31 +95,48 @@ struct miopen_apply
         add_generic_op<hip_add>("add");
         add_generic_op<hip_sub>("sub");
         add_generic_op<hip_exp>("exp");
+        add_generic_op<hip_erf>("erf");
         add_generic_op<hip_log>("log");
         add_generic_op<hip_sin>("sin");
         add_generic_op<hip_cos>("cos");
         add_generic_op<hip_tan>("tan");
         add_generic_op<hip_sinh>("sinh");
         add_generic_op<hip_cosh>("cosh");
+        add_generic_op<hip_tanh>("tanh");
         add_generic_op<hip_asin>("asin");
         add_generic_op<hip_acos>("acos");
         add_generic_op<hip_atan>("atan");
+        add_generic_op<hip_sqrt>("sqrt");
         add_generic_op<hip_mul>("mul");
+        add_generic_op<hip_div>("div");
         add_generic_op<hip_max>("max");
         add_generic_op<hip_min>("min");
+        add_generic_op<hip_rsqrt>("rsqrt");
+        add_generic_op<hip_round>("round");
+        add_generic_op<hip_pow>("pow");
+        add_generic_op<hip_sqdiff>("sqdiff");
+        add_generic_op<hip_relu>("relu");
+        add_generic_op<hip_sign>("sign");
+        add_generic_op<hip_sigmoid>("sigmoid");
 
-        add_extend_op<miopen_gemm, op::dot>("dot");
         add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
         add_extend_op<hip_concat, op::concat>("concat");
-        add_extend_op<miopen_softmax, op::softmax>("softmax");
+        add_extend_op<hip_softmax, op::softmax>("softmax");
         add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
+        add_extend_op<hip_argmax, op::argmax>("argmax");
+        add_extend_op<hip_argmin, op::argmin>("argmin");
         add_extend_op<hip_gather, op::gather>("gather");
         add_extend_op<hip_pad, op::pad>("pad");
         add_extend_op<hip_convert, op::convert>("convert");
         add_extend_op<hip_clip, op::clip>("clip");
+        add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
+        add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
+
+        add_gemm_op<op::dot>("dot");
+        add_gemm_op<op::quant_dot>("quant_dot");
 
         add_lrn_op();
         add_convolution_op();
+        add_quant_convolution_op();
         add_pooling_op();
         add_batch_norm_inference_op();
     }
@@ -154,6 +183,53 @@ struct miopen_apply
         });
     }
 
+    template <class Op>
+    void add_gemm_op(std::string name)
+    {
+        apply_map.emplace(name, [=](instruction_ref ins) {
+            auto&& op = any_cast<Op>(ins->get_operator());
+            auto beta = op.beta;
+            std::vector<instruction_ref> refs = ins->inputs();
+            if((refs.size() == 2) or (refs.size() == 3 and refs.back()->outputs().size() > 1) or
+               (ins == last))
+            {
+                auto output = insert_allocation(ins, ins->get_shape());
+                if(refs.size() == 2)
+                {
+                    beta = 0;
+                    refs.push_back(output);
+                }
+                else
+                {
+                    auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output);
+                    refs.back()   = copy_out;
+                    refs.push_back(copy_out);
+                }
+            }
+            else
+            {
+                refs.push_back(refs.back());
+            }
+
+            return prog->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs);
+        });
+    }
+
+    void add_quant_convolution_op()
+    {
+        apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
+            auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
+            auto conv = miopen_quant_convolution{op, make_conv(op)};
+            auto ws   = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
+
+            auto args      = ins->inputs();
+            auto workspace = insert_allocation(ins, ws, "workspace");
+            auto output    = insert_allocation(ins, ins->get_shape());
+
+            return prog->replace_instruction(ins, conv, args[0], args[1], workspace, output);
+        });
+    }
+
     void add_pooling_op()
     {
         apply_map.emplace("pooling", [=](instruction_ref ins) {
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

void pack_int8_args::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->name() == "gpu::quant_gemm")
        {
            auto inputs = ins->inputs();
            bool transa = inputs[0]->get_shape().transposed();
            bool transb = inputs[1]->get_shape().transposed();
            if(!transb)
            {
                auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()});
                auto output_b =
                    p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b});
                instruction::replace_argument(ins, inputs[1], output_b);
            }

            if(transa)
            {
                auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()});
                auto output_a =
                    p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a});
                instruction::replace_argument(ins, inputs[0], output_a);
            }
        }
        else if(ins->name() == "gpu::quant_convolution")
        {
            // pack both the data and the weights into padded buffers before the convolution
            auto inputs = ins->inputs();
            auto packed_x =
                p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())});
            auto output_x =
                p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x});
            instruction::replace_argument(ins, inputs[0], output_x);

            auto packed_w =
                p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())});
            auto output_w =
                p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w});
            instruction::replace_argument(ins, inputs[1], output_w);
        }
    }
}

shape pack_int8_args::pack_int8_shape(const shape& s) const
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
    }

    // round the second dimension up to a multiple of 4 and
    // recompute the batch stride to cover the padded rows
    auto lens    = s.lens();
    auto strides = s.strides();
    lens[1]      = (lens[1] + 3) / 4 * 4;
    strides[0]   = strides[1] * lens[1];

    return {s.type(), lens, strides};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
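
The shape arithmetic at the heart of pack_int8_args::pack_int8_shape rounds lens[1] up to the next multiple of 4 and widens the batch stride to match. A minimal standalone sketch with hypothetical NCHW dimensions {2, 3, 5, 5}:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<std::size_t> lens    = {2, 3, 5, 5};   // hypothetical int8 NCHW dims
    std::vector<std::size_t> strides = {75, 25, 5, 1}; // packed strides for those dims

    lens[1]    = (lens[1] + 3) / 4 * 4; // 3 channels -> 4
    strides[0] = strides[1] * lens[1];  // 25 * 4 = 100 elements per batch

    assert(lens[1] == 4);
    assert(strides[0] == 100);
    return 0;
}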
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
    // inputs are {x, w, workspace, output}
    check_shapes{inputs, *this}.has(4).standard();
    return op.compute_shape({inputs.at(0), inputs.at(1)});
}

argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
    auto x_desc = make_tensor(args[0].get_shape(), true);
    auto w_desc = make_tensor(args[1].get_shape(), true);
    auto y_desc = make_tensor(output_shape);

    float alpha = 1;
    float beta  = 0;
    auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
                                           &alpha,
                                           x_desc.get(),
                                           args[0].implicit(),
                                           w_desc.get(),
                                           args[1].implicit(),
                                           cd.get(),
                                           algo,
                                           &beta,
                                           y_desc.get(),
                                           args[3].implicit(),
                                           args[2].implicit(),
                                           args[2].get_shape().bytes());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
    }
    return args[3];
}

shape miopen_quant_convolution::compile(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    shape workspace_shape{};

    auto x_desc = make_tensor(inputs[0], true);
    auto w_desc = make_tensor(inputs[1], true);
    auto y_desc = make_tensor(output_shape);

    std::size_t workspace_size = 0;
    miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
                                             w_desc.get(),
                                             x_desc.get(),
                                             cd.get(),
                                             y_desc.get(),
                                             &workspace_size);
    workspace_shape = shape{shape::int8_type, {workspace_size}};

    // run the find pass with generated inputs laid out in the packed int8 shape
    auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
    auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
    auto y          = allocate_gpu(output_shape);
    auto workspace  = allocate_gpu(workspace_shape);

    int algo_count = 1;
    miopenConvAlgoPerf_t perf;
    auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
                                                        x_desc.get(),
                                                        arg_vec4_x.implicit(),
                                                        w_desc.get(),
                                                        arg_vec4_w.implicit(),
                                                        cd.get(),
                                                        y_desc.get(),
                                                        y.implicit(),
                                                        1,
                                                        &algo_count,
                                                        &perf,
                                                        workspace.implicit(),
                                                        workspace_size,
                                                        false);
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
    }
    handle = ctx.get_stream().get_miopen();
    algo   = perf.fwd_algo;
    return shape{shape::int8_type, {perf.memory}};
}

void miopen_quant_convolution::finalize(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    if(handle == ctx.get_stream().get_miopen())
        return;
    // Check that workspace hasn't changed
    auto size = inputs.at(2).bytes();
    auto ws   = compile(ctx, output_shape, std::move(inputs));
    if(ws.bytes() > size)
        MIGRAPHX_THROW("Workspace has changed during finalization.");
}

shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
    }

    auto lens    = s.lens();
    auto strides = s.strides();
    lens[1]      = (lens[1] + 3) / 4 * 4;
    strides[0]   = strides[1] * lens[1];

    return {s.type(), lens, strides};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_relu::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    return inputs.at(1);
}

argument miopen_relu::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    float alpha = 1;
    float beta  = 0;
    auto x_desc = make_tensor(args[0].get_shape());
    auto y_desc = make_tensor(output_shape);
    miopenActivationForward(ctx.get_stream().get_miopen(),
                            ad.get(),
                            &alpha,
                            x_desc.get(),
                            args[0].implicit(),
                            &beta,
                            y_desc.get(),
                            args[1].implicit());

    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
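
miopen_relu, and the sigmoid and tanh operators below, all call miopenActivationForward with alpha = 1 and beta = 0. Assuming MIOpen's usual blending convention y = alpha * act(x) + beta * y_prev, this requests the plain activation and ignores whatever was previously in the output buffer. A standalone numeric sketch of that assumption:

#include <cassert>

int main()
{
    float alpha = 1.0f, beta = 0.0f;
    float x = -2.5f, y_prev = 123.0f;         // y_prev is whatever was in the buffer
    float relu_x = x > 0 ? x : 0.0f;          // ReLU activation
    float y = alpha * relu_x + beta * y_prev; // beta = 0 discards the old contents
    assert(y == 0.0f);
    return 0;
}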
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_sigmoid::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    return inputs.at(1);
}

argument miopen_sigmoid::compute(context& ctx,
                                 const shape& output_shape,
                                 const std::vector<argument>& args) const
{
    float alpha = 1;
    float beta  = 0;
    auto x_desc = make_tensor(args[0].get_shape());
    auto y_desc = make_tensor(output_shape);
    miopenActivationForward(ctx.get_stream().get_miopen(),
                            ad.get(),
                            &alpha,
                            x_desc.get(),
                            args[0].implicit(),
                            &beta,
                            y_desc.get(),
                            args[1].implicit());

    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
 #include <migraphx/gpu/softmax.hpp>
+#include <migraphx/gpu/device/softmax.hpp>
 #include <migraphx/gpu/context.hpp>
 
 namespace migraphx {
@@ -30,6 +31,18 @@ argument miopen_softmax::compute(context& ctx,
     return args[1];
 }
 
+shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
+{
+    check_shapes{inputs, *this}.has(2).standard();
+    return op.compute_shape({inputs.at(0)});
+}
+
+argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
+{
+    device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
#include <migraphx/gpu/tanh.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape miopen_tanh::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).packed();
    return inputs.at(0);
}

argument miopen_tanh::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    float alpha = 1;
    float beta  = 0;
    auto x_desc = make_tensor(args[0].get_shape());
    auto y_desc = make_tensor(output_shape);
    miopenActivationForward(ctx.get_stream().get_miopen(),
                            ad.get(),
                            &alpha,
                            x_desc.get(),
                            args[0].implicit(),
                            &beta,
                            y_desc.get(),
                            args[1].implicit());

    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -13,14 +13,16 @@
 #include <migraphx/simplify_algebra.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
-#include <migraphx/common_subexpression_elimination.hpp>
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_rnn.hpp>
+#include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/schedule_model.hpp>
 #include <migraphx/gpu/adjust_allocation.hpp>
+#include <migraphx/gpu/pack_int8_args.hpp>
 #include <migraphx/eliminate_pad.hpp>
 #include <migraphx/schedule.hpp>
@@ -36,23 +38,26 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     // clang-format off
     return
     {
+        dead_code_elimination{},
+        simplify_reshapes{},
         dead_code_elimination{},
         eliminate_identity{},
         eliminate_pad{},
         dead_code_elimination{},
-        fwd_conv_batchnorm_rewrite{},
+        rewrite_batchnorm{},
         dead_code_elimination{},
         rewrite_rnn{},
+        rewrite_pooling{},
         dead_code_elimination{},
-        //common_subexpression_elimination{},
-        //dead_code_elimination{},
-        simplify_algebra{},
+        eliminate_common_subexpression{},
         dead_code_elimination{},
-        propagate_constant{},
+        simplify_algebra{},
         dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         dead_code_elimination{},
+        propagate_constant{},
+        dead_code_elimination{},
         lowering{ctx},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
@@ -60,6 +65,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         dead_code_elimination{},
         adjust_allocation{},
         dead_code_elimination{},
+        pack_int8_args{},
+        dead_code_elimination{},
         fuse_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
@@ -78,6 +85,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
 std::string target::name() const { return "miopen"; }
 
 migraphx::context target::get_context() const { return context{}; }
 
+argument target::copy_to(const argument& arg) const { return gpu::to_gpu(arg); }
+
+argument target::copy_from(const argument& arg) const { return gpu::from_gpu(arg); }
+
+argument target::allocate(const shape& s) const { return gpu::allocate_gpu(s); }
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -45,7 +45,7 @@ void write_literals::apply(program& p) const
             literal l  = ins->get_literal();
             auto pre   = p.add_literal(l);
             auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
-            p.replace_instruction(ins, hip_copy{}, pre, alloc);
+            p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
         }
         else
         {
@@ -21,6 +21,7 @@ set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
 
 add_library(migraphx_tf tf.cpp)
 set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
+rocm_set_soversion(migraphx_tf ${PROJECT_VERSION})
 rocm_clang_tidy_check(migraphx_tf)
 target_link_libraries(migraphx_tf PRIVATE tf-proto)
 target_link_libraries(migraphx_tf PUBLIC migraphx)
@@ -31,7 +32,7 @@ rocm_install_targets(
 
 add_executable(read_tf read_tf.cpp)
 rocm_clang_tidy_check(read_tf)
-target_link_libraries(read_tf migraphx_tf)
+target_link_libraries(read_tf migraphx_tf migraphx_cpu)
 
 if(MIGRAPHX_ENABLE_GPU)
 add_executable(verify_tf verify_tf.cpp)
@@ -17,6 +17,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/tf.hpp>
+#include <migraphx/pad_calc.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -24,8 +25,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct tf_parser
 {
     using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
-    using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
-    // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
+    using node_map = std::map<std::string, tensorflow::NodeDef>;
     using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
 
     node_map nodes;
@@ -36,7 +36,50 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;
 
-    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
+    bool should_transpose(instruction_ref ins) const
+    {
+        return is_nhwc and ins->get_shape().lens().size() == 4;
+    }
+
+    instruction_ref to_nhwc(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_nchw(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_kcxy(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref make_contiguous(instruction_ref ins)
+    {
+        if(ins->get_shape().standard())
+            return ins;
+        else
+            return prog.add_instruction(op::contiguous{}, ins);
+    }
+
+    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
+    {
+        std::vector<instruction_ref> result(args.size());
+        std::transform(
+            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
+        return result;
+    }
+
+    std::vector<size_t>
+    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
@@ -44,14 +87,14 @@ struct tf_parser
         if(is_nhwc)
         {
             std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
-                return parse_axis(axis);
+                return parse_axis(axis, num_dims);
             });
         }
         return axes;
     }
 
     template <class T>
-    std::vector<T> parse_axes(std::vector<T> axes) const
+    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
     {
         if(is_nhwc)
         {
@@ -59,7 +102,7 @@ struct tf_parser
             std::transform(axes.begin(),
                            axes.end(),
                            std::back_inserter(new_axes),
-                           [&](size_t axis) { return parse_axis(axis); });
+                           [&](size_t axis) { return parse_axis(axis, num_dims); });
             return new_axes;
         }
         return axes;
@@ -74,17 +117,17 @@ struct tf_parser
         std::vector<T> new_data(prev_data.size());
         for(size_t i = 0; i < new_data.size(); i++)
         {
-            auto new_idx = parse_axis(i);
+            auto new_idx         = parse_axis(i, new_data.size());
             new_data.at(new_idx) = prev_data.at(i);
         }
         prev_data = new_data;
     }
 
     template <class T>
-    T parse_axis(const T& dim) const
+    T parse_axis(const T& dim, const size_t num_dims) const
     {
         T new_dim = dim;
-        if(is_nhwc)
+        if(is_nhwc and num_dims >= 4)
         {
             switch(dim)
             {
@@ -105,70 +148,109 @@ struct tf_parser
         return axes;
     }
 
+    std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
+    {
+        uint32_t bitwise_compare = 1;
+        std::vector<int64_t> axes;
+        for(size_t i = 0; i < num_axes; i++)
+        {
+            // the LSB corresponds to axis 0 when determining which axes to begin
+            if(((mask >> i) & bitwise_compare) == 1)
+                axes.push_back(1);
+            else
+                axes.push_back(0);
+        }
+        return axes;
+    }
+
     tf_parser()
     {
+        add_generic_op("All", op::identity{});
         add_generic_op("Identity", op::identity{});
+        add_generic_op("LessEqual", op::identity{});
         add_generic_op("Relu", op::relu{});
         add_generic_op("Relu6", op::clip{6.0, 0.0});
+        add_generic_op("Rsqrt", op::rsqrt{});
+        add_generic_op("Tanh", op::tanh{});
+        add_generic_op("StopGradient", op::identity{});
 
         add_binary_op("Add", op::add{});
         add_binary_op("Mul", op::mul{});
+        add_binary_op("Pow", op::pow{});
+        add_binary_op("SquaredDifference", op::sqdiff{});
+        add_binary_op("Sub", op::sub{});
 
         add_mem_op("AvgPool", &tf_parser::parse_pooling);
+        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
+        add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
         add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
-        add_mem_op("ConcatV2", &tf_parser::parse_concat);
+        add_mem_op("Cast", &tf_parser::parse_cast, false);
+        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
+        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
-        add_mem_op("MatMul", &tf_parser::parse_matmul);
+        add_mem_op("GatherV2", &tf_parser::parse_gather, false);
+        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
-        add_mem_op("Mean", &tf_parser::parse_mean);
-        add_mem_op("Pack", &tf_parser::parse_pack);
+        add_mem_op("Mean", &tf_parser::parse_mean, false);
+        add_mem_op("OneHot", &tf_parser::parse_onehot, false);
+        add_mem_op("Pack", &tf_parser::parse_pack, false);
         add_mem_op("Pad", &tf_parser::parse_pad);
-        add_mem_op("Reshape", &tf_parser::parse_reshape);
-        add_mem_op("Softmax", &tf_parser::parse_softmax);
-        add_mem_op("Squeeze", &tf_parser::parse_squeeze);
-        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
+        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
+        add_mem_op("Slice", &tf_parser::parse_slice, false);
+        add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
+        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
+        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
+        add_mem_op("Transpose", &tf_parser::parse_transpose, false);
     }
 
-    template <class F>
-    void add_op(std::string name, F f)
-    {
-        ops.emplace(name, f);
-    }
-
-    // Multi output op
     template <class F>
-    void add_multi_op(std::string name, F f)
+    void add_op(std::string name, F f, bool transpose = true)
     {
-        ops.emplace(name, f);
+        if(transpose)
+        {
+            ops.emplace(name,
+                        op_func{[=](const attribute_map& attributes,
+                                    const std::vector<instruction_ref>& args) -> instruction_ref {
+                            return to_nhwc(f(attributes, to_nchw(args)));
+                        }});
+        }
+        else
+        {
+            ops.emplace(name, f);
+        }
     }
 
     template <class F>
-    void add_mem_op(std::string name, F f)
+    void add_mem_op(std::string name, F f, bool transpose = true)
     {
-        add_op(name, [=](auto&&... xs) {
-            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
-        });
+        add_op(name,
+               [=](auto&&... xs) {
+                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
+               },
+               transpose);
    }
 
     template <class T>
     void add_binary_op(std::string name, T x)
     {
-        add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
-            if(args.size() != 2)
-                MIGRAPHX_THROW("binary operators should have 2 operands");
-            auto l0 = args[1];
-            if(contains(attributes, "data_format"))
-            {
-                if(is_nhwc)
-                {
-                    l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
-                }
-            }
-            return add_broadcastable_binary_op(args[0], l0, x);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   if(args.size() != 2)
+                       MIGRAPHX_THROW("binary operators should have 2 operands");
+                   // TODO
+                   // if(contains(attributes, "data_format"))
+                   // {
+                   //     if(is_nhwc)
+                   //     {
+                   //         l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
+                   //     }
+                   // }
+                   return add_broadcastable_binary_op(args[0], args[1], x);
               },
               false);
     }
 
     template <class T>
@@ -207,20 +289,22 @@ struct tf_parser
             auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
             auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
-            return prog.add_instruction(x, l0, l1);
+            return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
         }
         else
        {
-            return prog.add_instruction(x, {arg0, arg1});
+            return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
        }
     }
 
     template <class T>
-    void add_generic_op(std::string name, T x)
+    void add_generic_op(std::string name, T x, bool transpose = true)
     {
-        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
-            return prog.add_instruction(x, args);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   return prog.add_instruction(x, args);
+               },
+               transpose);
     }
 
     instruction_ref
@@ -245,12 +329,19 @@ struct tf_parser
         return prog.add_instruction(op::add{}, args[0], l0);
     }
 
+    instruction_ref
+    parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        shape::type_t type = parse_type(attributes.at("DstT").type());
+        return prog.add_instruction(op::convert{type}, std::move(args));
+    }
+
     instruction_ref
     parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
+        size_t axis     = args[axis_idx]->eval().at<int64_t>();
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -261,45 +352,14 @@ struct tf_parser
                                    attribute_map attributes,
                                    const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        auto l0   = prog.add_literal(v);
-
-        size_t num_axes = l0->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            std::vector<int64_t> transpose_axes = get_axes(num_axes);
-            reorder_data(transpose_axes);
-            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
-        }
-        return l0;
+        return prog.add_literal(v);
     }
 
     instruction_ref
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         op::convolution op;
-        if(contains(attributes, "padding"))
-        {
-            const std::string& pad_mode = attributes.at("padding").s();
-            if(pad_mode.find("SAME") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::same;
-            }
-            else if(pad_mode.find("EXPLICIT") != std::string::npos)
-            {
-                std::vector<size_t> padding;
-                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
-                if(padding.size() != 4)
-                {
-                    MIGRAPHX_THROW("padding should have 4 values");
-                }
-                if(padding[0] != padding[2] || padding[1] != padding[3])
-                {
-                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
-                }
-                op.padding[0] = padding[0];
-                op.padding[1] = padding[1];
-            }
-        }
         if(contains(attributes, "strides"))
         {
             std::vector<size_t> stride;
@@ -324,22 +384,58 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
+
+        auto weights = to_kcxy(args[1]);
+        auto l0      = args[0];
+        if(contains(attributes, "padding"))
         {
-            if(is_nhwc)
+            const std::string& pad_mode = attributes.at("padding").s();
+            if(pad_mode.find("SAME") != std::string::npos)
             {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
+                op.padding_mode = op::padding_mode_t::same;
+                std::vector<size_t> weight_dims = weights->get_shape().lens();
+                size_t weight_h = weight_dims[2];
+                size_t weight_w = weight_dims[3];
+
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
+                calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
+
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
             }
-            else
+            else if(pad_mode.find("VALID") != std::string::npos)
             {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
+                op.padding_mode = op::padding_mode_t::valid;
+            }
+            else if(pad_mode.find("EXPLICIT") != std::string::npos)
+            {
+                std::vector<size_t> padding;
+                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
+                if(padding.size() != 4)
+                {
+                    MIGRAPHX_THROW("padding should have 4 values");
+                }
+                if(padding[0] != padding[2] || padding[1] != padding[3])
+                {
+                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
+                }
+                op.padding[0] = padding[0];
+                op.padding[1] = padding[1];
             }
         }
-
-        return prog.add_instruction(op, {args[0], weights});
+
+        return prog.add_instruction(op, {l0, to_kcxy(args[1])});
     }
 
     instruction_ref parse_depthwiseconv(const std::string&,
@@ -349,14 +445,7 @@ struct tf_parser
         op::convolution op;
         size_t num_channels = args[0]->get_shape().lens()[1];
         op.group            = num_channels;
-        if(contains(attributes, "padding"))
-        {
-            const std::string& pad_mode = attributes.at("padding").s();
-            if(pad_mode.find("SAME") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::same;
-            }
-        }
+
         if(contains(attributes, "strides"))
         {
             std::vector<size_t> stride;
@@ -369,17 +458,54 @@ struct tf_parser
             op.stride[0] = stride[2];
             op.stride[1] = stride[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
+
+        auto weights = to_kcxy(args[1]);
+        if(contains(attributes, "dilations"))
+        {
+            std::vector<size_t> dilation;
+            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
+            reorder_data(dilation);
+            if(dilation.size() != 4)
+            {
+                MIGRAPHX_THROW("dilation should have 4 values");
+            }
+            op.dilation[0] = dilation[2];
+            op.dilation[1] = dilation[3];
+        }
+
+        auto l0 = args[0];
+        if(contains(attributes, "padding"))
+        {
+            const std::string& pad_mode = attributes.at("padding").s();
+            if(pad_mode.find("SAME") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::same;
+                std::vector<size_t> weight_dims = weights->get_shape().lens();
+                size_t weight_h = weight_dims[2];
+                size_t weight_w = weight_dims[3];
+
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
+                calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
+
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
+            }
+            else if(pad_mode.find("VALID") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::valid;
+            }
+        }
@@ -394,10 +520,37 @@ struct tf_parser
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
         // Make sure weights are contiguous before doing reshape
-        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
-        return prog.add_instruction(op, {args[0], new_weights});
+        auto new_weights =
+            prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
+
+        return prog.add_instruction(op, {l0, new_weights});
+    }
+
+    instruction_ref
+    parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        std::vector<size_t> input_dims = args[0]->get_shape().lens();
+        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
+        size_t num_dims = input_dims.size();
+        int32_t dim     = args[1]->eval().at<int32_t>();
+        if(dim < 0)
+        {
+            new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
+        }
+        else
+        {
+            new_dims.insert(new_dims.begin() + dim, 1);
+        }
+        return prog.add_instruction(op::reshape{new_dims}, args[0]);
+    }
+
+    instruction_ref
+    parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        int axis = args[2]->eval().at<int32_t>();
+
+        op::gather op{axis};
+        return prog.add_instruction(op, {args[0], args[1]});
     }
 
     instruction_ref
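
The index arithmetic in parse_expanddims above places a new axis of length 1 at position num_dims + dim + 1 when dim is negative. A minimal standalone sketch with a hypothetical rank-2 input:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<int64_t> new_dims = {2, 3}; // hypothetical rank-2 input dims
    std::size_t num_dims          = new_dims.size();
    int32_t dim                   = -1;     // TF-style negative axis counts from the end

    if(dim < 0)
        new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); // position 2
    else
        new_dims.insert(new_dims.begin() + dim, 1);

    assert((new_dims == std::vector<int64_t>{2, 3, 1}));
    return 0;
}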
@@ -412,7 +565,16 @@ struct tf_parser
         }
         if(contains(attributes, "transpose_b"))
         {
-            transb = attributes.at("transpose_a").b();
+            transb = attributes.at("transpose_b").b();
+        }
+        if(contains(attributes, "adj_x"))
+        {
+            transa = attributes.at("adj_x").b();
+        }
+        if(contains(attributes, "adj_y"))
+        {
+            transb = attributes.at("adj_y").b();
         }
 
         std::vector<int64_t> perm(args[0]->get_shape().lens().size());
@@ -429,23 +591,44 @@ struct tf_parser
     instruction_ref
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        auto axes      = parse_axes(args[1]->eval().get<int32_t>().to_vector());
         bool keep_dims = attributes.at("keep_dims").b();
-        std::vector<int32_t> hw_axes{2, 3};
-        // check if conditions for GlobalAvgPool are met
-        auto lens = args[0]->get_shape().lens();
-        if(axes == hw_axes and lens.size() == 4)
+        auto axes      = args[1]->eval().get<int32_t>().to_vector<int64_t>();
+
+        if(keep_dims)
+        {
+            return prog.add_instruction(op::reduce_mean{axes}, args[0]);
+        }
+        else
         {
-            op::pooling op{"average"};
-            op.lengths[0] = lens[2];
-            op.lengths[1] = lens[3];
-            auto l0 = prog.add_instruction(op, args.front());
-            if(keep_dims)
-                return l0;
-            return prog.add_instruction(
-                op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
+            auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
+            return prog.add_instruction(op::squeeze{axes}, ins);
         }
-        MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
+    }
+
+    instruction_ref
+    parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
+
+        int64_t axis    = -1;
+        float on_value  = args[2]->eval().at<float>();
+        float off_value = args[3]->eval().at<float>();
+
+        std::vector<float> depth_input(depth * depth, off_value);
+        for(int i = 0; i < depth; i++)
+        {
+            depth_input[depth * i + i] = on_value;
+        }
+
+        if(contains(attributes, "axis"))
+            axis = attributes.at("axis").i();
+        if(axis == -1)
+        {
+            shape s{shape::float_type, {depth, depth}};
+            auto l0 = prog.add_literal({s, depth_input});
+            return prog.add_instruction(op::gather{0}, {l0, args[0]});
+        }
+        MIGRAPHX_THROW("MIGraphX does not support axis != -1");
     }
 
     instruction_ref parse_pack(const std::string&,
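
parse_onehot above builds a depth x depth lookup table and lowers OneHot to a gather over its rows. A minimal standalone sketch of the table construction for hypothetical depth = 3, on_value = 1, off_value = 0:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::size_t depth = 3; // hypothetical number of classes
    float on_value = 1.0f, off_value = 0.0f;

    // depth x depth table; row i is the one-hot encoding of class i
    std::vector<float> table(depth * depth, off_value);
    for(std::size_t i = 0; i < depth; i++)
        table[depth * i + i] = on_value;

    assert((table == std::vector<float>{1, 0, 0, 0, 1, 0, 0, 0, 1}));
    return 0;
}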
@@ -463,16 +646,14 @@ struct tf_parser
             MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                            " must be smaller than input size " + to_string(input_size));
         }
-        // check if input arg needs axis to be converted to NCHW
-        if(input_size >= 4)
-            axis = parse_axis(axis);
 
         std::transform(
             args.begin(),
             args.end(),
             std::back_inserter(unsqueezed_args),
             [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
-        return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
+        return to_nhwc(
+            prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
     }
 
     instruction_ref
@@ -508,18 +689,6 @@ struct tf_parser
     {
         op::pooling op{starts_with(name, "Max") ? "max" : "average"};
 
-        if(contains(attributes, "padding"))
-        {
-            const std::string& pad_mode = attributes.at("padding").s();
-            if(pad_mode.find("SAME") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::same;
-            }
-            else if(pad_mode.find("VALID") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::valid;
-            }
-        }
         if(contains(attributes, "strides"))
         {
             std::vector<size_t> stride;
@@ -544,7 +713,39 @@ struct tf_parser
             op.lengths[0] = ksize[2];
             op.lengths[1] = ksize[3];
         }
-        return prog.add_instruction(op, args[0]);
+
+        auto l0 = args[0];
+        if(contains(attributes, "padding"))
+        {
+            const std::string& pad_mode = attributes.at("padding").s();
+            if(pad_mode.find("SAME") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::same;
+                auto input_dims = l0->get_shape().lens();
+                size_t input_h  = input_dims[2];
+                size_t input_w  = input_dims[3];
+                std::vector<int64_t> pads(input_dims.size());
+                calculate_padding(0, pads, input_h, op.stride[0], 1, op.lengths[0]);
+                calculate_padding(1, pads, input_w, op.stride[1], 1, op.lengths[1]);
+
+                if(pads[0] != pads[2] || pads[1] != pads[3])
+                {
+                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
+                    l0 = prog.add_instruction(
+                        migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
+                }
+                else
+                {
+                    op.padding[0] = pads[0];
+                    op.padding[1] = pads[1];
+                }
+            }
+            else if(pad_mode.find("VALID") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::valid;
+            }
+        }
+
+        return prog.add_instruction(op, l0);
     }
 
     instruction_ref
@@ -555,7 +756,7 @@ struct tf_parser
             MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
         auto s = args[1]->eval();
         s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
 
     void parse_from(std::istream& is)
@@ -572,13 +773,46 @@ struct tf_parser
     }
 
     instruction_ref
-    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        op::slice op;
+        auto starts     = args[1]->eval().get<int32_t>().to_vector();
+        auto size       = args[2]->eval().get<int32_t>().to_vector();
+        auto axes       = args[0]->get_shape().lens();
+        size_t num_axes = axes.size();
+
+        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
+        op.ends   = std::vector<int64_t>(num_axes);
+        op.axes   = std::vector<int64_t>(num_axes);
+        std::iota(op.axes.begin(), op.axes.end(), 0);
+        for(size_t i = 0; i < num_axes; i++)
+        {
+            if(size[i] == -1)
+                op.ends[i] = axes[i];
+            else
+                op.ends[i] = starts[i] + size[i];
+        }
+        return prog.add_instruction(op, make_contiguous(args[0]));
+    }
+
+    // template to facilitate the logsoftmax later
+    template <class Op>
+    instruction_ref parse_softmax(const std::string&,
+                                  const attribute_map& attributes,
+                                  std::vector<instruction_ref> args)
     {
-        auto dims = args.front()->get_shape().lens();
-        auto r =
-            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
-        auto s = prog.add_instruction(op::softmax{}, r);
-        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
+        int axis      = -1;
+        auto num_dims = args[0]->get_shape().lens().size();
+        if(contains(attributes, "axis"))
+        {
+            axis = static_cast<int>(attributes.at("axis").i());
+        }
+        if(axis < 0)
+        {
+            axis += num_dims;
+        }
+        return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
     }
 
     instruction_ref parse_squeeze(const std::string&,
@@ -586,20 +820,21 @@ struct tf_parser
                                   std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto axes = parse_axes(attributes, "squeeze_dims");
+        auto input_dims = args[0]->get_shape().lens();
+        auto axes       = attributes.at("squeeze_dims").list().i();
         copy(axes, std::back_inserter(op.axes));
-        auto args0_dims = args[0]->get_shape().lens();
 
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for(size_t i = 0; i < args0_dims.size(); i++)
+            for(size_t i = 0; i < input_dims.size(); i++)
             {
-                if(args0_dims.at(i) == 1)
+                if(input_dims.at(i) == 1)
                 {
                     op.axes.push_back(i);
                 }
             }
         }
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
 
     instruction_ref parse_stridedslice(const std::string&,
@@ -607,39 +842,68 @@ struct tf_parser
                                        std::vector<instruction_ref> args)
     {
         op::slice op;
         auto starts = args[1]->eval().get<int32_t>().to_vector();
         auto ends   = args[2]->eval().get<int32_t>().to_vector();
-        size_t num_axes = args[0]->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            reorder_data(starts);
-            reorder_data(ends);
-        }
+
+        auto l0                  = args[0];
+        size_t num_axes          = l0->get_shape().lens().size();
+        std::vector<size_t> axes = l0->get_shape().lens();
+
         op.starts = std::vector<int64_t>(starts.begin(), starts.end());
         op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
         op.axes   = std::vector<int64_t>(num_axes);
         std::iota(op.axes.begin(), op.axes.end(), 0);
+
+        uint32_t begin_mask       = 0;
+        uint32_t end_mask         = 0;
         uint32_t shrink_axis_mask = 0;
         uint32_t bitwise_compare  = 1;
         std::vector<int64_t> squeeze_axes;
 
+        if(contains(attributes, "begin_mask"))
+            begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
+
+        if(contains(attributes, "end_mask"))
+            end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
+
         if(contains(attributes, "shrink_axis_mask"))
             shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
 
+        std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
+        std::vector<int64_t> end_axes   = get_axes_from_mask(num_axes, end_mask);
+
+        for(size_t i = 0; i < num_axes; i++)
+        {
+            if(begin_axes.at(i) == 1)
+            {
+                op.starts.at(i) = 0;
+            }
+            if(end_axes.at(i) == 1)
+            {
+                op.ends.at(i) = axes.at(i);
+            }
+        }
+
+        auto l1 = prog.add_instruction(op, l0);
+        if(shrink_axis_mask == 0)
+            return l1;
+
         for(size_t i = 0; i < num_axes; i++)
         {
             // the LSB corresponds to axis 0 when determining which axes to squeeze
             if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                 squeeze_axes.push_back(i);
         }
-        if(num_axes >= 4)
-        {
-            squeeze_axes = parse_axes(squeeze_axes);
-        }
-        auto l0 = prog.add_instruction(op, args[0]);
-        return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
+
+        return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
+    }
+
+    instruction_ref
+    parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        auto perm = args[1]->eval().get<int32_t>().to_vector();
+        op::transpose op;
+        op.dims = std::vector<int64_t>(perm.begin(), perm.end());
+
+        return prog.add_instruction(op, args.front());
    }
 
     void parse_graph(const tensorflow::GraphDef& graph)
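
parse_stridedslice above decodes begin_mask, end_mask, and shrink_axis_mask bit by bit (LSB = axis 0) via get_axes_from_mask. A minimal standalone sketch of that decoding for a hypothetical 4-axis mask:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same bit convention as get_axes_from_mask: bit i set means axis i is selected.
std::vector<int64_t> axes_from_mask(std::size_t num_axes, uint32_t mask)
{
    std::vector<int64_t> axes;
    for(std::size_t i = 0; i < num_axes; i++)
        axes.push_back(((mask >> i) & 1u) == 1 ? 1 : 0);
    return axes;
}

int main()
{
    // mask 0b0101 (= 5) selects axes 0 and 2 of a 4-d tensor
    assert((axes_from_mask(4, 5u) == std::vector<int64_t>{1, 0, 1, 0}));
    return 0;
}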
@@ -656,7 +920,7 @@ struct tf_parser
             reorder_data(dims);
         }
         shape s            = shape{shape_type, dims};
-        instructions[name] = prog.add_parameter(name, s);
+        instructions[name] = to_nhwc(prog.add_parameter(name, s));
     }
 
     for(auto&& p : nodes)
     {
@@ -669,10 +933,16 @@ struct tf_parser
         if(instructions.count(name) == 0)
         {
             auto&& node = nodes.at(name);
+            // assert ops ignored
+            if(node.op() == "Assert" or contains(name, "Assert"))
+                return;
             std::vector<instruction_ref> args;
 
             for(auto&& input : node.input())
             {
+                // control dependencies (signified by ^ before the name) are ignored
+                if(contains(input, "^"))
+                    continue;
                 if(nodes.count(input) > 0)
                 {
                     auto&& iname = get_name(nodes.at(input));
...@@ -732,72 +1002,56 @@ struct tf_parser ...@@ -732,72 +1002,56 @@ struct tf_parser
        shape::type_t shape_type{};
        switch(t)
        {
        case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
        case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
        case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
        case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
        case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
        case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
        case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
        case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
        case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
        case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // tf pb should not use these types
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        return shape_type;
...@@ -812,61 +1066,59 @@ struct tf_parser
            const std::string& s = t.tensor_content();
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return literal{{shape::float_type, dims}, s.data()};
            case tensorflow::DataType::DT_BOOL:
            case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
            case tensorflow::DataType::DT_UINT16:
            case tensorflow::DataType::DT_INT16:
                return literal{{shape::int16_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT32:
                return literal{{shape::int32_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT64:
                return literal{{shape::int64_type, dims}, s.data()};
            case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, s.data()};
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                throw std::runtime_error("");
            }
...@@ -874,11 +1126,9 @@ struct tf_parser
        }
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return create_literal(
                    shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
            case tensorflow::DataType::DT_INT8:
                return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
            case tensorflow::DataType::DT_UINT16:
...@@ -890,7 +1140,6 @@ struct tf_parser
            case tensorflow::DataType::DT_INT64:
                return create_literal(
                    shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
            case tensorflow::DataType::DT_BOOL:
                return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
            case tensorflow::DataType::DT_HALF:
...@@ -906,43 +1155,45 @@ struct tf_parser
            }
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                throw std::runtime_error("");
            }
...@@ -1006,6 +1257,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
#else
    parser.parse_from(input);
#endif
parser.to_nchw(std::prev(parser.prog.end()));
    return std::move(parser.prog);
}
...
...@@ -119,7 +119,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
  set(TEST_NAME test_${BASE_NAME})
  add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST})
  rocm_clang_tidy_check(${TEST_NAME})
  target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_cpu)
  target_include_directories(${TEST_NAME} PUBLIC include)
  add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
  add_dependencies(tests ${TEST_NAME})
...@@ -129,7 +129,7 @@ endforeach()
# tf test
add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf migraphx_cpu)
target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_dependencies(tests test_tf)
...
...@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2)
    }
}
TEST_CASE(quant_dot_2args_multi4)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
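        // gold spot-check: with both inputs filled by iota, output[0][0] is
        // dot(row 0 of l1, col 0 of l2) = 0*0 + 1*8 + 2*16 + 3*24 = 112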
std::vector<int> gold = {112, 118, 124, 130, 136, 142, 148, 154, 304, 326, 348,
370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686,
724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704,
736, 768, 592, 628, 664, 700, 736, 772, 808, 844};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822,
974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708,
836, 964, 74, 218, 362, 506, 650, 794, 938, 1082};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_2args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {
210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(
migraphx::op::quant_dot{
2,
},
l1,
tl2);
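        // quant_dot{2,} sets alpha = 2 (assumed: output = alpha * A * B when no
        // accumulator is given): gold[0] = 2 * dot([0,1,2,3], [0,1,2,3]) = 2 * 14 = 28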
std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2);
std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
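        // with three args the op computes alpha * A * B + beta * C (assumed form);
        // gold[0] = 1 * dot([0,2,...,14], [0,7,...,49]) + 3 * 2 = 1960 + 6 = 1966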
std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_batch)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};
std::vector<int8_t> data1(4 * 2 * 4);
std::vector<int8_t> data2(4 * 4 * 7);
std::vector<int> data3(4 * 2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
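        // batched case: the 2x4 * 4x7 dot runs independently per {2, 2} batch slice;
        // gold[0] = 1 * dot([0,1,2,3], [0,7,14,21]) + 2 * 2 = 98 + 4 = 102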
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};
std::vector<int8_t> data1(48);
std::vector<int8_t> data2(96);
std::vector<int> data3(72);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
...@@ -527,6 +528,51 @@ TEST_CASE(exp_test)
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(erf_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {4}};
auto l =
p.add_literal(migraphx::literal{s, {0.73785057, 1.58165966, -0.43597795, -0.01677432}});
p.add_instruction(migraphx::op::erf{}, l);
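    // erf is the Gauss error function applied elementwise,
    // e.g. erf(0.73785057) ~= 0.70327317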
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.70327317, 0.97470088, -0.46247893, -0.01892602};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
p.add_instruction(migraphx::op::sqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sign_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}});
p.add_instruction(migraphx::op::sign{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test)
{
    migraphx::program p;
...@@ -541,6 +587,21 @@ TEST_CASE(log_test)
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
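    // elementwise power of base by exponent: {1^1, 2^2, 3^3} = {1, 4, 27}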
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test)
{
    migraphx::program p;
...@@ -929,6 +990,21 @@ TEST_CASE(maxpool_test)
    EXPECT(migraphx::verify_range(results_vector, c));
}
TEST_CASE(softmax_simple_test)
{
migraphx::program p;
std::vector<float> a = {0.25, 0.75};
std::vector<float> s = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
p.add_instruction(migraphx::op::softmax{1}, al);
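    // softmax over axis 1: exp(0.25) / (exp(0.25) + exp(0.75)) ~= 0.377541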
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(softmax_test)
{
    migraphx::program p;
...@@ -1002,14 +1078,13 @@ TEST_CASE(logsoftmax_test_axis_0)
        -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
    std::vector<float> s = {
        -0.135261, -2.843968, -0.659995, -0.488413, -1.051857, -2.812936, -0.250956, -0.353985,
        -1.155980, -0.603651, -0.211969, -0.175371, -1.336552, -3.885010, -1.871544, -0.837083,
        -0.887745, -0.433338, -1.158864, -4.911197, -1.147972, -0.666711, -0.996874, -0.981418,
        -0.851145, -0.853988, -0.858112, -2.067420, -0.059956, -0.727436, -0.950881, -0.429689,
        -0.061906, -1.505332, -1.210277, -0.377970, -0.791448, -1.655428, -1.827253, -0.304828,
        -0.020762, -0.167101, -0.567346, -0.530319, -1.045094, -0.376648, -0.007391, -0.381670,
        -0.720302, -0.460499, -0.469651, -0.556740, -0.554628, -0.551582};
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
...@@ -1036,14 +1111,13 @@ TEST_CASE(logsoftmax_test_axis_1)
        -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
    std::vector<float> s = {
        -0.550468, -2.132973, -1.549746, -0.650533, -1.051529, -2.248570, -0.141017, -2.028357,
        -1.947730, -1.511324, -0.166597, -0.379726, -1.965689, -1.172109, -1.475721, -2.700831,
        -1.537011, -0.658754, -1.596017, -3.353137, -2.266743, -1.084197, -1.076214, -0.406712,
        -2.743019, -0.425526, -1.079083, -2.139486, -1.270584, -1.024088, -1.154231, -3.201762,
        -0.888957, -0.532855, -3.103583, -1.221339, -1.355980, -3.531678, -1.438510, -0.975194,
        -0.080261, -1.162697, -1.568557, -1.398519, -1.322129, -0.470660, -0.370953, -0.907343,
        -1.179017, -3.312239, -1.286363, -1.586076, -0.345100, -0.824173};
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
...@@ -1070,14 +1144,13 @@ TEST_CASE(logsoftmax_test_axis_2)
        -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
    std::vector<float> s = {
        -0.495957, -1.031212, -0.245531, -2.013726, -1.339125, -2.465619, -1.356652, -0.964037,
        -2.019250, -0.214522, -0.289569, -0.234392, -2.086591, -2.684439, -2.851651, -2.674176,
        -1.697424, -1.889155, -0.401029, -3.064586, -1.173030, -1.306912, -2.177020, -0.834262,
        -2.818177, -0.174415, -1.361105, -1.024571, -0.106766, -1.167645, -1.072650, -2.576522,
        -0.569261, -1.207483, -3.679894, -2.095913, -0.504264, -3.039291, -1.290559, -1.156812,
        -0.126453, -0.551493, -2.506384, -2.646261, -1.905195, -0.206994, -0.191369, -0.959754,
        -1.948685, -3.671233, -0.875521, -3.111952, -1.905644, -1.6076011};
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
...@@ -1104,14 +1177,13 @@ TEST_CASE(logsoftmax_test_axis_3)
        -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
    std::vector<float> s = {
        -0.336904, -3.475825, -1.366154, -0.279366, -2.208430, -2.010934, -0.225511, -2.436562,
        -2.167785, -1.572415, -1.784104, -0.470789, -1.067459, -1.801948, -0.711023, -2.307197,
        -1.467087, -0.400681, -0.426983, -3.740518, -1.127681, -1.078919, -2.599005, -0.534965,
        -2.561400, -0.567617, -1.033025, -2.097713, -0.520463, -1.262245, -1.763230, -2.607658,
        -0.281299, -0.814243, -2.627210, -0.724131, -0.655704, -2.123055, -1.018163, -2.480634,
        -0.382599, -1.451479, -1.843102, -0.915303, -0.818078, -1.316929, -0.508875, -2.033541,
        -1.487672, -2.417791, -0.378360, -2.568531, -0.569794, -1.028032};
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
...@@ -1124,38 +1196,112 @@ TEST_CASE(logsoftmax_test_axis_3)
    EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(argmax_test_0)
{
    migraphx::program p;
    std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
                               -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
                               0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
    std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
    migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
    auto dl = p.add_literal(migraphx::literal{data_shape, data});
    p.add_instruction(migraphx::op::argmax{0}, dl);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<int64_t> result_vec;
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_1)
{
    migraphx::program p;
    std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
                               -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
                               0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
    std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
    migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
    auto dl = p.add_literal(migraphx::literal{data_shape, data});
    p.add_instruction(migraphx::op::argmax{1}, dl);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<int64_t> result_vec;
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
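// argmax/argmin return int64 indices along the chosen axis; for axis 2 below each
// {2, 3} slice yields one index, e.g. row {1.2255, 1.6834, -2.0305, -0.3221} -> 1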
TEST_CASE(argmax_test_2)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_0)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_1)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_2)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
} }
TEST_CASE(conv2d_test)
...@@ -1338,6 +1484,107 @@ TEST_CASE(conv2d_padding_stride_test)
    EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{}, al, cl);
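    // int8 inputs accumulate into int32; a 4x4 image with a 3x3 kernel and no
    // padding gives a 2x2 map per output channel, so the result shape is {2, 2, 2, 2}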
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893,
51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {4521,
7014,
7830,
11952,
10515,
16734,
19737,
30906,
13161,
19542,
19494,
28800,
34707,
52590,
54729,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(transpose_test)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
...@@ -1574,7 +1821,7 @@ TEST_CASE(fp32_fp16_test)
    auto test_case = [&](std::vector<std::string>&& op_names) {
        std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
        auto p = create_program();
        migraphx::quantize_fp16(p, op_names);
        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<float> res;
...@@ -1603,4 +1850,238 @@ TEST_CASE(clip_test)
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_sum_axis0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0}}, l0);
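    // e.g. gold[0] sums position (0, 0) across the three axis-0 slices: 1 + 5 + 9 = 15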
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{15, 18, 21, 24};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{4, 6, 12, 14, 20, 22};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3, 7, 11, 15, 19, 23};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
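// reducing axes 0 and 2 together leaves only axis 1: 1+2+5+6+9+10 = 33 and 3+4+7+8+11+12 = 45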
p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{33, 45};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 26, 42};
EXPECT(results_vector == gold);
}
TEST_CASE(rsqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
p.add_instruction(migraphx::op::rsqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
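// with an int32 input the mean is computed in integer arithmetic, so the mean of {5, 6, 7, 8} (6.5) truncates to 6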
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
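// sqdiff is the elementwise squared difference (x - y)^2: (-1-1)^2 = (0-2)^2 = (1-3)^2 = 4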
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(round_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
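// the gold values below imply half-away-from-zero rounding: 1.5 -> 2.0 and -1.5 -> -2.0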
p.add_instruction(migraphx::op::round{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
{
std::cout << v << "\t";
}
std::cout << std::endl;
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(op_capture)
{
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
auto p1 = p.add_literal(s1, d1);
auto p2 = p.add_literal(s1, d1);
auto pb = p.add_literal(s2, d2);
auto pc = p.add_literal(s2, d2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p;
migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(capture_p, t, {"dot"});
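// capturing the arguments of the dot ops must not change the computed results, so the instrumented and plain programs are expected to agree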
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
auto cap_res = capture_p.eval({});
auto res = p.eval({});
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
-#include <migraphx/common_subexpression_elimination.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
@@ -9,7 +9,7 @@ struct cse_target
std::string name() const { return "dce"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
-return {migraphx::common_subexpression_elimination{}, migraphx::dead_code_elimination{}};
+return {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
...
@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE(standard_op)
{
migraphx::program p;
-auto l = p.add_literal(get_2x2());
+auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count);
}
-TEST_CASE(non_standard_op)
+TEST_CASE(standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(non_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(non_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(transpose_gemm)
{
migraphx::program p;
@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op)
{
migraphx::program p;
-auto l = p.add_literal(get_2x2());
+auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(transpose_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 3);
}
TEST_CASE(no_packed_unary_op)
{
migraphx::program p;
...
@@ -83,23 +83,4 @@ TEST_CASE(rewrite_test_asymmetric)
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
-TEST_CASE(rewrite_test_same_padding)
-{
-migraphx::program p;
-size_t img_dim[2] = {2, 2};
-size_t channels = 1;
-std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
-std::iota(input.begin(), input.end(), 0);
-migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
-auto l_img = p.add_literal(migraphx::literal{s_img, input});
-auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
-create_conv(padded_img, channels, p, migraphx::op::padding_mode_t::same);
-p.compile(eliminate_pad_target{});
-EXPECT(std::any_of(
-p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
-}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -243,6 +243,43 @@ struct test_exp : verify_program<test_exp>
}
};
struct test_erf : verify_program<test_erf>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::erf{}, param);
return p;
}
};
struct test_sqrt : verify_program<test_sqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
auto param_abs = p.add_instruction(migraphx::op::abs{}, param);
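// abs is applied first so that sqrt only ever sees non-negative inputs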
p.add_instruction(migraphx::op::sqrt{}, param_abs);
return p;
}
};
struct test_sign : verify_program<test_sign>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::sign{}, param);
return p;
}
};
struct test_log : verify_program<test_log>
{
migraphx::program create_program() const
@@ -255,6 +292,20 @@ struct test_log : verify_program<test_log>
}
};
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
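// the exponent literal is filled with 2.0f, so this amounts to squaring the base parameter elementwise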
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
@@ -451,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
}
};
struct test_mul_add : verify_program<test_mul_add>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape bs{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto a = p.add_parameter("a", bs);
auto b = p.add_parameter("b", bs);
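// broadcast the {3} vectors along axis 1 so they line up with the {2, 3} input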
auto ab = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
auto bb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
p.add_instruction(migraphx::op::add{}, mul, bb);
return p;
}
};
struct test_add_broadcast : verify_program<test_add_broadcast>
{
migraphx::program create_program() const
@@ -569,13 +638,45 @@ struct test_sub2 : verify_program<test_sub2>
}
};
-struct test_softmax : verify_program<test_softmax>
+struct test_div : verify_program<test_div>
{
migraphx::program create_program() const
{
migraphx::program p;
-auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 4, 2}});
-p.add_instruction(migraphx::op::softmax{}, x);
+migraphx::shape s{migraphx::shape::float_type, {3}};
+auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, z);
return p;
}
};
struct test_div2 : verify_program<test_div2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, zb);
return p;
}
};
struct test_softmax1 : verify_program<test_softmax1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
p.add_instruction(migraphx::op::softmax{0}, x);
return p;
}
};
@@ -592,6 +693,53 @@ struct test_softmax2 : verify_program<test_softmax2>
}
};
template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
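// the extent 1067 is deliberately irregular, presumably to exercise softmax reductions that do not divide evenly across GPU workgroups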
migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
return p;
}
};
template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
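// 1025 (= 1024 + 1) presumably targets the boundary case where the reduced axis just exceeds a 1024-wide workgroup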
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct test_conv : verify_program<test_conv>
{
migraphx::program create_program() const
@@ -679,6 +827,77 @@ struct test_add_relu : verify_program<test_add_relu>
}
};
struct test_add_sigmoid : verify_program<test_add_sigmoid>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::sigmoid{}, add);
return p;
}
};
struct test_add_tanh : verify_program<test_add_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::tanh{}, add);
return p;
}
};
struct test_triadd_relu : verify_program<test_triadd_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::relu{}, triadd);
return p;
}
};
struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::sigmoid{}, triadd);
return p;
}
};
struct test_triadd_tanh : verify_program<test_triadd_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::tanh{}, triadd);
return p;
}
};
struct test_sigmoid : verify_program<test_sigmoid>
{
migraphx::program create_program() const
@@ -1238,6 +1457,114 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
}
};
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
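// quant_dot{alpha, beta} computes alpha * (a x b) + beta * c on int8 inputs with int32 accumulation (alpha/beta semantics assumed to mirror op::dot)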
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
return p;
}
};
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
return p;
}
};
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
return p;
}
};
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
@@ -1367,6 +1694,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
}
};
struct quant_conv : verify_program<quant_conv>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{}, pa, pc);
return p;
}
};
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
pa,
pc);
return p;
}
};
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
pa,
pc);
return p;
}
};
struct quant_conv_padding : verify_program<quant_conv_padding>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
return p;
}
};
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
@@ -1441,6 +1845,22 @@ struct test_pad : verify_program<test_pad>
}
};
struct test_pad_int8 : verify_program<test_pad_int8>
{
migraphx::program create_program() const
{
migraphx::program p;
std::vector<int8_t> data0 = {0, 1, 2, 3};
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{};
op.value = std::numeric_limits<int8_t>::lowest();
op.pads = {0, 0, 1, 1};
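// the pad value is int8's lowest (-128); note the literal's shape is float_type, so the int8 data and pad value are converted, presumably to exercise pad's value handling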
p.add_instruction(op, l0);
return p;
}
};
struct test_pooling_autopad : verify_program<test_pooling_autopad>
{
migraphx::program create_program() const
@@ -2631,10 +3051,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
-migraphx::op::gru{hidden_size,
-{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
-migraphx::op::rnn_direction::forward,
-clip},
+migraphx::op::lstm{
+hidden_size,
+{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
+migraphx::op::rnn_direction::forward,
+clip},
seq,
w,
r,
@@ -3308,33 +3729,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
}
};
-template <int Axis>
-struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
-{
-migraphx::program create_program() const
-{
-migraphx::program p;
-migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
-auto param = p.add_parameter("0", s);
-p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
-return p;
-}
-};
-template struct test_logsoftmax<0>;
-template struct test_logsoftmax<1>;
-template struct test_logsoftmax<2>;
-template struct test_logsoftmax<3>;
-template struct test_logsoftmax<4>;
-template <int Axis>
-struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
+template <int Axis, migraphx::shape::type_t T>
+struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
-migraphx::shape s{migraphx::shape::float_type, {3}};
+migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
@@ -3342,8 +3743,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
}
};
-template struct test_logsoftmax_1<0>;
-template struct test_logsoftmax_1<1>;
+template struct test_logsoftmax<0, migraphx::shape::float_type>;
+template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{
@@ -3356,7 +3765,7 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
-migraphx::quantize(p, {"all"});
+migraphx::quantize_fp16(p, {"all"});
return p;
};
};
@@ -3372,7 +3781,7 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
-migraphx::quantize(p, {"add"});
+migraphx::quantize_fp16(p, {"add"});
return p;
};
};
@@ -3388,7 +3797,7 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
-migraphx::quantize(p, {"add"});
+migraphx::quantize_fp16(p, {"add"});
return p;
};
@@ -3405,7 +3814,134 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
-migraphx::quantize(p, {"sub"});
+migraphx::quantize_fp16(p, {"sub"});
return p;
};
};
struct test_reduce_sum : verify_program<test_reduce_sum>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 1026, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_reduce_sum_int : verify_program<test_reduce_sum_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_rsqrt : verify_program<test_rsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
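// clip{max, min} clamps the inputs to [1.0, FLT_MAX] (argument order assumed from this clip constructor), keeping rsqrt away from zero and negative values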
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 128, 768}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
struct test_round : verify_program<test_round>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::round{}, param);
return p;
};
};
struct test_convert : verify_program<test_convert>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {8, 24}};
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
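// quant_dot requires int8 operands, so both float parameters are converted before the multiply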
auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p;
};
...