Commit 9ba36e10 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into hcc26

parents 7c2af597 292b6aab
#ifndef MIGRAPHX_GUARD_RTGLIB_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_POW_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/pow.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_pow : binary_device<hip_pow, device::pow>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/rsqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_rsqrt : unary_device<hip_rsqrt, device::rsqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqdiff.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqdiff : binary_device<hip_sqdiff, device::sqdiff>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqrt : unary_device<hip_sqrt, device::sqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/gpu/logsoftmax.hpp> #include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/add.hpp> #include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp> #include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/div.hpp>
#include <migraphx/gpu/exp.hpp> #include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/erf.hpp> #include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp> #include <migraphx/gpu/log.hpp>
...@@ -51,7 +52,11 @@ ...@@ -51,7 +52,11 @@
#include <migraphx/gpu/convert.hpp> #include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp> #include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -98,9 +103,14 @@ struct miopen_apply ...@@ -98,9 +103,14 @@ struct miopen_apply
add_generic_op<hip_asin>("asin"); add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos"); add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan"); add_generic_op<hip_atan>("atan");
add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul"); add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max"); add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
......
...@@ -79,7 +79,8 @@ struct tf_parser ...@@ -79,7 +79,8 @@ struct tf_parser
return result; return result;
} }
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
{ {
auto attrs = attributes.at(s).list().i(); auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes; std::vector<size_t> axes;
...@@ -87,14 +88,14 @@ struct tf_parser ...@@ -87,14 +88,14 @@ struct tf_parser
if(is_nhwc) if(is_nhwc)
{ {
std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) { std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
return parse_axis(axis); return parse_axis(axis, num_dims);
}); });
} }
return axes; return axes;
} }
template <class T> template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
{ {
if(is_nhwc) if(is_nhwc)
{ {
...@@ -102,7 +103,7 @@ struct tf_parser ...@@ -102,7 +103,7 @@ struct tf_parser
std::transform(axes.begin(), std::transform(axes.begin(),
axes.end(), axes.end(),
std::back_inserter(new_axes), std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); }); [&](size_t axis) { return parse_axis(axis, num_dims); });
return new_axes; return new_axes;
} }
return axes; return axes;
...@@ -117,17 +118,17 @@ struct tf_parser ...@@ -117,17 +118,17 @@ struct tf_parser
std::vector<T> new_data(prev_data.size()); std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++) for(size_t i = 0; i < new_data.size(); i++)
{ {
auto new_idx = parse_axis(i); auto new_idx = parse_axis(i, new_data.size());
new_data.at(new_idx) = prev_data.at(i); new_data.at(new_idx) = prev_data.at(i);
} }
prev_data = new_data; prev_data = new_data;
} }
template <class T> template <class T>
T parse_axis(const T& dim) const T parse_axis(const T& dim, const size_t num_dims) const
{ {
T new_dim = dim; T new_dim = dim;
if(is_nhwc) if(is_nhwc and num_dims >= 4)
{ {
switch(dim) switch(dim)
{ {
...@@ -153,10 +154,13 @@ struct tf_parser ...@@ -153,10 +154,13 @@ struct tf_parser
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{}); add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
...@@ -165,6 +169,7 @@ struct tf_parser ...@@ -165,6 +169,7 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
...@@ -175,6 +180,7 @@ struct tf_parser ...@@ -175,6 +180,7 @@ struct tf_parser
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
} }
template <class F> template <class F>
...@@ -490,6 +496,25 @@ struct tf_parser ...@@ -490,6 +496,25 @@ struct tf_parser
return prog.add_instruction(op, {l0, new_weights}); return prog.add_instruction(op, {l0, new_weights});
} }
instruction_ref
parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size();
int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0)
{
new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
}
else
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return prog.add_instruction(op::reshape{new_dims}, args[0]);
}
instruction_ref instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -519,11 +544,12 @@ struct tf_parser ...@@ -519,11 +544,12 @@ struct tf_parser
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met // check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
if(axes == hw_axes and lens.size() == 4) if(axes == hw_axes and lens.size() == 4)
{ {
op::pooling op{"average"}; op::pooling op{"average"};
...@@ -694,14 +720,15 @@ struct tf_parser ...@@ -694,14 +720,15 @@ struct tf_parser
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
op::squeeze op; op::squeeze op;
auto axes = attributes.at("squeeze_dims").list().i(); auto input_dims = args[0]->get_shape().lens();
auto axes = attributes.at("squeeze_dims").list().i();
copy(axes, std::back_inserter(op.axes)); copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{ {
for(size_t i = 0; i < args0_dims.size(); i++) for(size_t i = 0; i < input_dims.size(); i++)
{ {
if(args0_dims.at(i) == 1) if(input_dims.at(i) == 1)
{ {
op.axes.push_back(i); op.axes.push_back(i);
} }
...@@ -741,6 +768,16 @@ struct tf_parser ...@@ -741,6 +768,16 @@ struct tf_parser
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0)); return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
instruction_ref
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
......
...@@ -542,6 +542,21 @@ TEST_CASE(erf_test) ...@@ -542,6 +542,21 @@ TEST_CASE(erf_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
p.add_instruction(migraphx::op::sqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -556,6 +571,21 @@ TEST_CASE(log_test) ...@@ -556,6 +571,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test) TEST_CASE(sin_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1778,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12) ...@@ -1778,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(rsqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
p.add_instruction(migraphx::op::rsqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_mean_axis1) TEST_CASE(reduce_mean_axis1)
{ {
migraphx::program p; migraphx::program p;
...@@ -1853,4 +1897,19 @@ TEST_CASE(reduce_mean_int) ...@@ -1853,4 +1897,19 @@ TEST_CASE(reduce_mean_int)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf> ...@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf>
} }
}; };
struct test_sqrt : verify_program<test_sqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
auto param_abs = p.add_instruction(migraphx::op::abs{}, param);
p.add_instruction(migraphx::op::sqrt{}, param_abs);
return p;
}
};
struct test_log : verify_program<test_log> struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -267,6 +280,20 @@ struct test_log : verify_program<test_log> ...@@ -267,6 +280,20 @@ struct test_log : verify_program<test_log>
} }
}; };
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin> struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -581,6 +608,38 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -581,6 +608,38 @@ struct test_sub2 : verify_program<test_sub2>
} }
}; };
struct test_div : verify_program<test_div>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, z);
return p;
}
};
struct test_div2 : verify_program<test_div2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, zb);
return p;
}
};
struct test_softmax1 : verify_program<test_softmax1> struct test_softmax1 : verify_program<test_softmax1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3511,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half> ...@@ -3511,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
}; };
}; };
struct test_rsqrt : verify_program<test_rsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
};
};
struct test_reduce_mean : verify_program<test_reduce_mean> struct test_reduce_mean : verify_program<test_reduce_mean>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -202,6 +202,16 @@ TEST_CASE(erf_test) ...@@ -202,6 +202,16 @@ TEST_CASE(erf_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input);
auto prog = migraphx::parse_onnx("sqrt_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -889,4 +899,30 @@ TEST_CASE(clip_test) ...@@ -889,4 +899,30 @@ TEST_CASE(clip_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3);
auto prog = migraphx::parse_onnx("pow_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_bcast_test1.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
pow2:q

0
1out"Powpow_testZ
0




Z
1



b
out




B
pow2:u

0
1out"Powpow_testZ
0




Z
1




b
out




B
 sqrt-example:C
xy"Sqrt test_sqrtZ
x


b
y


B
:
0 Placeholder*
shape:*
dtype0

rsqrtRsqrt0*
T0"
\ No newline at end of file
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
*
sqdiffSquaredDifference01*
T0"
\ No newline at end of file
:
0 Placeholder*
dtype0*
shape:
(
stopgradient StopGradient0*
T0"
\ No newline at end of file
...@@ -159,6 +159,31 @@ TEST_CASE(depthwiseconv_test) ...@@ -159,6 +159,31 @@ TEST_CASE(depthwiseconv_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(expanddims_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = optimize_tf("expanddims_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(expanddims_test_neg_dims)
{
// this check makes sure the pb parses negative dim value correctly
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1);
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto prog = optimize_tf("expanddims_neg_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(identity_test) TEST_CASE(identity_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -326,6 +351,16 @@ TEST_CASE(reshape_test) ...@@ -326,6 +351,16 @@ TEST_CASE(reshape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(rsqrt_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::rsqrt{}, l0);
auto prog = optimize_tf("rsqrt_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -339,6 +374,17 @@ TEST_CASE(softmax_test) ...@@ -339,6 +374,17 @@ TEST_CASE(softmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sqdiff_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l0, l1);
auto prog = optimize_tf("sqdiff_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(squeeze_test) TEST_CASE(squeeze_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -349,6 +395,16 @@ TEST_CASE(squeeze_test) ...@@ -349,6 +395,16 @@ TEST_CASE(squeeze_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(stopgradient_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_tf("stopgradient_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(stridedslice_test) TEST_CASE(stridedslice_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -389,4 +445,16 @@ TEST_CASE(tanh_test) ...@@ -389,4 +445,16 @@ TEST_CASE(tanh_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}};
p.add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment