Unverified commit 305ae9a2, authored by Aaron Enye Shi, committed by GitHub

Merge branch 'develop' into enable-miopen-hipclang

parents 93888cf8 dcbc9255
#ifndef MIGRAPHX_GUARD_RTGLIB_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_POW_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/pow.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_pow : binary_device<hip_pow, device::pow>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/rsqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_rsqrt : unary_device<hip_rsqrt, device::rsqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqdiff.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqdiff : binary_device<hip_sqdiff, device::sqdiff>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqrt : unary_device<hip_sqrt, device::sqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
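
These four headers are thin wrappers: unary_device and binary_device generate the operator boilerplate and dispatch to a device kernel of the same name. As a rough sketch of the kernel side (an assumption, not part of this diff; the real kernel would live in the corresponding device/*.cpp), device::pow presumably follows the same elementwise nary pattern as the other binary device ops:

// Sketch only: assumes the repo's nary() elementwise launcher.
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    // one thread per output element; ::pow is the HIP math function
    // (the real kernel may need an explicit conversion for half types)
    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return ::pow(x, y); });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx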
@@ -52,7 +52,11 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -99,10 +103,14 @@ struct miopen_apply
add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan");
add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
@@ -154,13 +154,17 @@ struct tf_parser
add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
@@ -177,6 +181,7 @@ struct tf_parser
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
}
template <class F>
@@ -526,6 +531,15 @@ struct tf_parser
transb = attributes.at("transpose_a").b();
}
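// BatchMatMul carries adj_x/adj_y where MatMul carries transpose_a/transpose_b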
if(contains(attributes, "adj_x"))
{
transa = attributes.at("adj_x").b();
}
if(contains(attributes, "adj_y"))
{
transb = attributes.at("adj_y").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
@@ -764,6 +778,16 @@ struct tf_parser
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
}
instruction_ref
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
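// TF passes the permutation as a constant second input; evaluate it to a host literal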
auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
@@ -854,72 +878,56 @@ struct tf_parser
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
// tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
@@ -934,61 +942,59 @@ struct tf_parser
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
@@ -996,11 +1002,9 @@ struct tf_parser
}
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT:
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
@@ -1012,7 +1016,6 @@ struct tf_parser
case tensorflow::DataType::DT_INT64:
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
@@ -1028,43 +1031,45 @@ struct tf_parser
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
@@ -542,6 +542,21 @@ TEST_CASE(erf_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
p.add_instruction(migraphx::op::sqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test)
{
migraphx::program p;
@@ -556,6 +571,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test)
{
migraphx::program p;
@@ -1778,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12)
EXPECT(results_vector == gold);
}
TEST_CASE(rsqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
p.add_instruction(migraphx::op::rsqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
@@ -1853,4 +1897,19 @@ TEST_CASE(reduce_mean_int)
EXPECT(results_vector == gold);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf>
}
};
struct test_sqrt : verify_program<test_sqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
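// take abs() first so sqrt only ever sees non-negative inputs during verification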
auto param_abs = p.add_instruction(migraphx::op::abs{}, param);
p.add_instruction(migraphx::op::sqrt{}, param_abs);
return p;
}
};
struct test_log : verify_program<test_log>
{
migraphx::program create_program() const
@@ -267,6 +280,20 @@ struct test_log : verify_program<test_log>
}
};
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
@@ -3543,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
};
};
struct test_rsqrt : verify_program<test_rsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
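// clip inputs into [1.0, FLT_MAX] so rsqrt stays finite and well-defined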
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
[Binary protobuf test data, rendered as text by the diff viewer: new ONNX graphs — cast-example (a Cast node with a "to" attribute, graph test_cast, input x, output y); two constant-of-shape graphs (a Constant shape_tensor feeding a ConstantOfShape node, one with and one without a "value" attribute); and an expand graph (a constant shape feeding an Expand node, input x, output y).]
@@ -202,6 +202,16 @@ TEST_CASE(erf_test)
EXPECT(p == prog);
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input);
auto prog = migraphx::parse_onnx("sqrt_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test)
{
migraphx::program p;
@@ -889,4 +899,104 @@ TEST_CASE(clip_test)
EXPECT(p == prog);
}
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3);
auto prog = migraphx::parse_onnx("pow_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_bcast_test1.onnx");
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_float)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape1.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_int64)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_no_value_attr)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape3.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_empty_input)
{
migraphx::program p;
p.add_literal(migraphx::literal());
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape4.onnx");
EXPECT(p == prog);
}
TEST_CASE(expand_test)
{
migraphx::program p;
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = p.add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
[Binary protobuf test data: two pow ONNX graphs (a Pow node, graph pow_test, inputs 0 and 1, output out) that differ only in the shapes of the two inputs, matching the broadcast and non-broadcast parser tests above.]
[Binary protobuf test data: sqrt-example ONNX graph (a Sqrt node, graph test_sqrt, input x, output y).]
[Binary protobuf test data: two TensorFlow GraphDefs — one with two Placeholder nodes feeding a BatchMatMul node carrying adj_x and adj_y attributes, and one with a Placeholder feeding an Rsqrt node.]