Commit 784dc2aa authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into multiply-add

parents 70641651 dcbc9255
#include <migraphx/gpu/device/sqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::sqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -8,7 +8,7 @@ namespace device {
void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; });
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x - y; });
}
} // namespace device
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void rsqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DIV_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/div.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_div : binary_device<hip_div, device::div>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -88,7 +88,7 @@ struct binary_device : oper<Derived>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
F(ctx.get_stream().get(), args[2], args[1], args[0]);
F(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
......
#ifndef MIGRAPHX_GUARD_RTGLIB_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_POW_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/pow.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_pow : binary_device<hip_pow, device::pow>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reduce_mean
{
op::reduce_mean op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::reduce_mean"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/rsqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_rsqrt : unary_device<hip_rsqrt, device::rsqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqdiff.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqdiff : binary_device<hip_sqdiff, device::sqdiff>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqrt : unary_device<hip_sqrt, device::sqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -26,6 +26,7 @@
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/div.hpp>
#include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp>
......@@ -51,6 +52,11 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -97,9 +103,14 @@ struct miopen_apply
add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan");
add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
......@@ -113,6 +124,7 @@ struct miopen_apply
add_extend_op<hip_convert, op::convert>("convert");
add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
add_lrn_op();
add_convolution_op();
......
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reduce_mean.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_reduce_mean::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument
hip_reduce_mean::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::reduce_mean(ctx.get_stream().get(), args.back(), args.front());
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -79,7 +79,8 @@ struct tf_parser
return result;
}
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
......@@ -87,14 +88,14 @@ struct tf_parser
if(is_nhwc)
{
std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
return parse_axis(axis);
return parse_axis(axis, num_dims);
});
}
return axes;
}
template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const
std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
{
if(is_nhwc)
{
......@@ -102,7 +103,7 @@ struct tf_parser
std::transform(axes.begin(),
axes.end(),
std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); });
[&](size_t axis) { return parse_axis(axis, num_dims); });
return new_axes;
}
return axes;
......@@ -117,17 +118,17 @@ struct tf_parser
std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = parse_axis(i);
auto new_idx = parse_axis(i, new_data.size());
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
template <class T>
T parse_axis(const T& dim) const
T parse_axis(const T& dim, const size_t num_dims) const
{
T new_dim = dim;
if(is_nhwc)
if(is_nhwc and num_dims >= 4)
{
switch(dim)
{
......@@ -153,18 +154,23 @@ struct tf_parser
add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
......@@ -175,6 +181,7 @@ struct tf_parser
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
}
template <class F>
......@@ -490,6 +497,25 @@ struct tf_parser
return prog.add_instruction(op, {l0, new_weights});
}
instruction_ref
parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size();
int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0)
{
new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
}
else
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return prog.add_instruction(op::reshape{new_dims}, args[0]);
}
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -505,6 +531,15 @@ struct tf_parser
transb = attributes.at("transpose_a").b();
}
if(contains(attributes, "adj_x"))
{
transa = attributes.at("adj_x").b();
}
if(contains(attributes, "adj_y"))
{
transb = attributes.at("adj_y").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
......@@ -519,11 +554,12 @@ struct tf_parser
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
if(axes == hw_axes and lens.size() == 4)
{
op::pooling op{"average"};
......@@ -694,14 +730,15 @@ struct tf_parser
std::vector<instruction_ref> args)
{
op::squeeze op;
auto axes = attributes.at("squeeze_dims").list().i();
auto input_dims = args[0]->get_shape().lens();
auto axes = attributes.at("squeeze_dims").list().i();
copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(size_t i = 0; i < args0_dims.size(); i++)
for(size_t i = 0; i < input_dims.size(); i++)
{
if(args0_dims.at(i) == 1)
if(input_dims.at(i) == 1)
{
op.axes.push_back(i);
}
......@@ -741,6 +778,16 @@ struct tf_parser
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
}
instruction_ref
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
......@@ -831,72 +878,56 @@ struct tf_parser
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_INVALID:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case tensorflow::DataType::DT_COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case tensorflow::DataType::DT_QINT8:
break; // throw std::runtime_error("Unsupported type QINT8");
case tensorflow::DataType::DT_QUINT8:
break; // throw std::runtime_error("Unsupported type QUINT8");
case tensorflow::DataType::DT_QINT32:
break; // throw std::runtime_error("Unsupported type QINT32");
case tensorflow::DataType::DT_BFLOAT16:
break; // throw std::runtime_error("Unsupported type BFLOAT16");
case tensorflow::DataType::DT_QINT16:
break; // throw std::runtime_error("Unsupported type QINT16");
case tensorflow::DataType::DT_QUINT16:
break; // throw std::runtime_error("Unsupported type QUINT16");
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_RESOURCE:
break; // throw std::runtime_error("Unsupported type RESOURCE");
case tensorflow::DataType::DT_VARIANT:
break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64:
shape_type = shape::uint64_type;
break;
// tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF: break;
case tensorflow::DataType::DT_DOUBLE_REF: break;
case tensorflow::DataType::DT_INT32_REF: break;
case tensorflow::DataType::DT_UINT8_REF: break;
case tensorflow::DataType::DT_INT16_REF: break;
case tensorflow::DataType::DT_INT8_REF: break;
case tensorflow::DataType::DT_STRING_REF: break;
case tensorflow::DataType::DT_COMPLEX64_REF: break;
case tensorflow::DataType::DT_INT64_REF: break;
case tensorflow::DataType::DT_BOOL_REF: break;
case tensorflow::DataType::DT_QINT8_REF: break;
case tensorflow::DataType::DT_QUINT8_REF: break;
case tensorflow::DataType::DT_QINT32_REF: break;
case tensorflow::DataType::DT_BFLOAT16_REF: break;
case tensorflow::DataType::DT_QINT16_REF: break;
case tensorflow::DataType::DT_QUINT16_REF: break;
case tensorflow::DataType::DT_UINT16_REF: break;
case tensorflow::DataType::DT_COMPLEX128_REF: break;
case tensorflow::DataType::DT_HALF_REF: break;
case tensorflow::DataType::DT_RESOURCE_REF: break;
case tensorflow::DataType::DT_VARIANT_REF: break;
case tensorflow::DataType::DT_UINT32_REF: break;
case tensorflow::DataType::DT_UINT64_REF: break;
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: break;
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
......@@ -911,61 +942,59 @@ struct tf_parser
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
......@@ -973,11 +1002,9 @@ struct tf_parser
}
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
......@@ -989,7 +1016,6 @@ struct tf_parser
case tensorflow::DataType::DT_INT64:
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
......@@ -1005,43 +1031,45 @@ struct tf_parser
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
......
......@@ -542,6 +542,21 @@ TEST_CASE(erf_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
p.add_instruction(migraphx::op::sqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test)
{
migraphx::program p;
......@@ -556,6 +571,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test)
{
migraphx::program p;
......@@ -1703,7 +1733,7 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_sum_test0)
TEST_CASE(reduce_sum_axis0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1718,7 +1748,7 @@ TEST_CASE(reduce_sum_test0)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test1)
TEST_CASE(reduce_sum_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1733,7 +1763,7 @@ TEST_CASE(reduce_sum_test1)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test2)
TEST_CASE(reduce_sum_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1748,7 +1778,7 @@ TEST_CASE(reduce_sum_test2)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test02)
TEST_CASE(reduce_sum_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1763,7 +1793,7 @@ TEST_CASE(reduce_sum_test02)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test12)
TEST_CASE(reduce_sum_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1778,4 +1808,108 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(rsqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
p.add_instruction(migraphx::op::rsqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf>
}
};
struct test_sqrt : verify_program<test_sqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
auto param_abs = p.add_instruction(migraphx::op::abs{}, param);
p.add_instruction(migraphx::op::sqrt{}, param_abs);
return p;
}
};
struct test_log : verify_program<test_log>
{
migraphx::program create_program() const
......@@ -267,6 +280,20 @@ struct test_log : verify_program<test_log>
}
};
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
......@@ -581,6 +608,38 @@ struct test_sub2 : verify_program<test_sub2>
}
};
struct test_div : verify_program<test_div>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, z);
return p;
}
};
struct test_div2 : verify_program<test_div2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, zb);
return p;
}
};
struct test_softmax1 : verify_program<test_softmax1>
{
migraphx::program create_program() const
......@@ -3511,4 +3570,53 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
};
};
struct test_rsqrt : verify_program<test_rsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment