Commit 9b5e0c18 authored by Paul

Merge branch 'develop' into stream_execution_checkin

parents 00442cd2 3499ec7d
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "attr_value.proto";
import "types.proto";
// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field, which should match the name of an OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;
// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;
// Human readable description.
string description = 2;
// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;
// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};
// Description of the input(s).
repeated ArgDef input_arg = 2;
// Description of the output(s).
repeated ArgDef output_arg = 3;
// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;
// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;
// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;
// Human-readable description.
string description = 4;
// TODO(josh11b): bool is_optional?
// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.
// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;
// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;
// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;
// One-line human-readable description of what the Op does.
string summary = 5;
// Additional, longer human-readable description of what the Op does.
string description = 6;
// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.
// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;
// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
// TODO(josh11b): Implement that optimization.
bool is_aggregate = 16; // for things like add
// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.
// -------------------------------------------------------------------------
// Optimization constraints.
// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue
// -------------------------------------------------------------------------
// Non-standard options.
// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)
// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;
// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};
// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};
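To make the OpDef/NodeDef relationship concrete, here is a hedged sketch (not part of this commit; the generated header name and the load_op_index helper are illustrative) that indexes a serialized OpList by name using the protoc-generated C++ API, which is how a NodeDef's "op" string gets resolved to its OpDef:

#include <op_def.pb.h> // generated by protoc from op_def.proto (name assumed)
#include <fstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Build a name -> OpDef index so a NodeDef's "op" field can be looked up.
std::unordered_map<std::string, tensorflow::OpDef> load_op_index(const std::string& path)
{
    std::ifstream in(path, std::ios::binary);
    tensorflow::OpList op_list;
    if(!op_list.ParseFromIstream(&in))
        throw std::runtime_error("failed to parse OpList: " + path);
    std::unordered_map<std::string, tensorflow::OpDef> index;
    for(const auto& op : op_list.op())
        index[op.name()] = op; // OpDef names are unique and CamelCase
    return index;
}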
#include <migraphx/tf.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <iostream>
#include <string>
#include <cstring>
migraphx::program::parameter_map create_param_map(const migraphx::program& p, bool gpu = true)
{
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(gpu)
m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
else
m[x.first] = migraphx::generate_argument(x.second);
}
return m;
}
int main(int argc, char const* argv[])
{
if(argc > 1)
{
bool is_nhwc = true;
if(argc > 2)
{
if(strcmp(argv[2], "nchw") == 0)
is_nhwc = false;
}
std::string file = argv[1];
std::size_t n = argc > 3 ? std::stoul(argv[3]) : 50;
auto p = migraphx::parse_tf(file, is_nhwc);
std::cout << "Compiling ... " << std::endl;
p.compile(migraphx::gpu::target{});
std::cout << "Allocating params ... " << std::endl;
auto m = create_param_map(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
}
}
#include <migraphx/tf.hpp>
#include <iostream>
#include <string>
#include <cstring>
int main(int argc, char const* argv[])
{
if(argc > 1)
{
bool is_nhwc = true;
if(argc > 2)
{
if(strcmp(argv[2], "nchw") == 0)
is_nhwc = false;
}
std::string file = argv[1];
auto prog = migraphx::parse_tf(file, is_nhwc);
std::cout << prog << std::endl;
}
}
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;
// Container in which this resource is placed.
string container = 2;
// Unique name of this resource.
string name = 3;
// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;
// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";
// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
TensorShapeProto tensor_shape = 2;
// Only one of the representations below is set, either "tensor_content" or
// one of the "xxx_val" attributes. We are not using oneof because oneofs
// cannot contain repeated fields, so that would require an extra set of
// messages.
// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;
// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.
// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];
// DT_FLOAT.
repeated float float_val = 5 [packed = true];
// DT_DOUBLE.
repeated double double_val = 6 [packed = true];
// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];
// DT_STRING
repeated bytes string_val = 8;
// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];
// DT_INT64
repeated int64 int64_val = 10 [packed = true];
// DT_BOOL
repeated bool bool_val = 11 [packed = true];
// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];
// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;
// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;
// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];
// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};
// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}
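The version-0 fill convention above deserves a concrete example. A hedged sketch follows, assuming the usual protoc-generated C++ accessors; the make_filled_float helper name is illustrative. A constant tensor of any shape can be encoded with a single element in the matching "xxx_val" field:

#include <tensor.pb.h> // generated by protoc from tensor.proto (name assumed)
#include <cstdint>
#include <vector>

// Encode a tensor of shape `dims` whose every element equals `value`, using
// the version-0 rule that a single float_val entry fills the whole shape.
tensorflow::TensorProto make_filled_float(float value, const std::vector<int64_t>& dims)
{
    tensorflow::TensorProto t;
    t.set_dtype(tensorflow::DT_FLOAT);
    for(int64_t d : dims)
        t.mutable_tensor_shape()->add_dim()->set_size(d);
    t.add_float_val(value); // one element stands for the entire tensor
    return t;
}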
// Protocol buffer representing the shape of tensors.
syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
package tensorflow;
// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;
// Optional name of the tensor dimension.
string name = 2;
};
// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;
// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};
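To make the dim/unknown_rank contract concrete, a minimal sketch using the protoc-generated C++ API (the header name and function wrapper are assumptions):

#include <tensor_shape.pb.h> // generated from tensor_shape.proto (name assumed)

void shape_examples()
{
    // A rank-2 shape whose batch dimension is unknown:
    tensorflow::TensorShapeProto partly_known;
    partly_known.add_dim()->set_size(-1); // -1 marks an unknown dimension
    partly_known.add_dim()->set_size(10);
    // A shape whose rank itself is unknown; "dim" must then stay empty:
    tensorflow::TensorShapeProto rank_unknown;
    rank_unknown.set_unknown_rank(true);
}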
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>
#include <numeric>
#include <algorithm>
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
bool is_nhwc = true;
std::unordered_map<std::string, op_func> ops;
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
if(is_nhwc)
{
std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
return parse_axis(axis);
});
}
return axes;
}
template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const
{
// in NCHW mode the axes are already in the expected order
if(not is_nhwc)
return axes;
std::vector<T> new_axes;
std::transform(axes.begin(),
axes.end(),
std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); });
return new_axes;
}
// TF stores certain attributes, such as strides and dilations, as 4D values
// in NHWC order, e.g. strides = {1, stride_h, stride_w, 1}: the N and C
// entries are 1 and the relevant data sits in the H and W positions.
// This helper reorders the values into NCHW so the respective operator
// member variables can read the spatial entries at indices 2 and 3.
template <class T>
void reorder_data(std::vector<T>& prev_data) const
{
std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = parse_axis(i);
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
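// Illustrative example (not from the original source): with is_nhwc == true,
// a strides attribute of {1, 2, 3, 1} (N, H, W, C order) becomes
// {1, 1, 2, 3} (N, C, H, W): index 0 stays, 1 -> 2, 2 -> 3, 3 -> 1,
// so the spatial values can then be read at indices 2 and 3.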
template <class T>
T parse_axis(const T& dim) const
{
T new_dim = dim;
if(is_nhwc)
{
switch(dim)
{
case 0: new_dim = 0; break;
case 1: new_dim = 2; break;
case 2: new_dim = 3; break;
case 3: new_dim = 1; break;
default: break;
}
}
return new_dim;
}
std::vector<int64_t> get_axes(size_t num_axes) const
{
std::vector<int64_t> axes(num_axes);
std::iota(axes.begin(), axes.end(), 0);
return axes;
}
tf_parser()
{
add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_binary_op("Add", op::add{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze);
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
template <class F>
void add_mem_op(std::string name, F f)
{
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
auto l0 = args[1];
if(contains(attributes, "data_format"))
{
if(is_nhwc)
{
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
}
}
return add_broadcastable_binary_op(args[0], l0, x);
});
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape() != arg1->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<size_t>* s0 = &arg0->get_shape().lens();
const std::vector<size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
}
}
template <class T>
void add_generic_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon"))
{
epsilon = attributes.at("epsilon").f();
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0);
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_tensor(attributes.at("value").tensor());
auto l0 = prog.add_literal(v);
size_t num_axes = l0->get_shape().lens().size();
if(num_axes >= 4)
{
std::vector<int64_t> transpose_axes = get_axes(num_axes);
reorder_data(transpose_axes);
l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
}
return l0;
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::convolution op;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
reorder_data(stride);
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
reorder_data(dilation);
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
return prog.add_instruction(op, {args[0], weights});
}
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
if(axes == hw_axes and keep_dims)
{
op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()};
op.lengths[0] = input_dims[2];
op.lengths[1] = input_dims[3];
return prog.add_instruction(op, args.front());
}
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2 * ndims; i += 2)
{
pad_per_dim[i / 2].first = tf_padding[i];
pad_per_dim[i / 2].second = tf_padding[i + 1];
}
reorder_data(pad_per_dim);
op::pad op;
std::vector<int64_t> pads(ndims * 2);
for(size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
}
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{starts_with(name, "Max") ? "max" : "average"};
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
reorder_data(stride);
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(attributes, "ksize"))
{
std::vector<size_t> ksize;
copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
reorder_data(ksize);
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
op::reshape op;
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args[0]);
}
void parse_from(std::istream& is)
{
tensorflow::GraphDef graph;
if(graph.ParseFromIstream(&is))
{
this->parse_graph(graph);
}
else
{
throw std::runtime_error("Failed reading tf file");
}
}
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto dims = args.front()->get_shape().lens();
auto r =
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(op::softmax{}, r);
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref parse_squeeze(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::squeeze op;
auto axes = parse_axes(attributes, "squeeze_dims");
copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(size_t i = 0; i < args0_dims.size(); i++)
{
if(args0_dims.at(i) == 1)
{
op.axes.push_back(i);
}
}
}
return prog.add_instruction(op, args[0]);
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes)
{
const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
if(is_nhwc and dims.size() >= 4)
{
reorder_data(dims);
}
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
}
for(auto&& p : nodes)
{
this->parse_node(p.first);
}
}
void parse_node(const std::string& name)
{
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
}
else
{
args.push_back(instructions.at(input));
}
}
if(ops.count(node.op()) == 0)
{
instructions[name] = prog.add_instruction(unknown{node.op()}, args);
}
else
{
instructions[name] = ops[node.op()](get_attributes(node), args);
}
}
}
static attribute_map get_attributes(const tensorflow::NodeDef& node)
{
attribute_map result;
for(auto&& attr : node.attr())
{
result[attr.first] = attr.second;
}
return result;
}
static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
static node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{
node_map result;
for(auto&& node : graph.node())
{
auto node_name = get_name(node);
// assume each node in graph has an associated name
if(node_name.empty())
MIGRAPHX_THROW("tf node with no name found");
result[node_name] = node;
if(node.op() == "Placeholder")
{
input_nodes.push_back(node);
}
}
return result;
}
static shape::type_t parse_type(const tensorflow::DataType t)
{
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_INVALID:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case tensorflow::DataType::DT_COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case tensorflow::DataType::DT_QINT8:
break; // throw std::runtime_error("Unsupported type QINT8");
case tensorflow::DataType::DT_QUINT8:
break; // throw std::runtime_error("Unsupported type QUINT8");
case tensorflow::DataType::DT_QINT32:
break; // throw std::runtime_error("Unsupported type QINT32");
case tensorflow::DataType::DT_BFLOAT16:
break; // throw std::runtime_error("Unsupported type BFLOAT16");
case tensorflow::DataType::DT_QINT16:
break; // throw std::runtime_error("Unsupported type QINT16");
case tensorflow::DataType::DT_QUINT16:
break; // throw std::runtime_error("Unsupported type QUINT16");
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_RESOURCE:
break; // throw std::runtime_error("Unsupported type RESOURCE");
case tensorflow::DataType::DT_VARIANT:
break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64:
shape_type = shape::uint64_type;
break;
// tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF: break;
case tensorflow::DataType::DT_DOUBLE_REF: break;
case tensorflow::DataType::DT_INT32_REF: break;
case tensorflow::DataType::DT_UINT8_REF: break;
case tensorflow::DataType::DT_INT16_REF: break;
case tensorflow::DataType::DT_INT8_REF: break;
case tensorflow::DataType::DT_STRING_REF: break;
case tensorflow::DataType::DT_COMPLEX64_REF: break;
case tensorflow::DataType::DT_INT64_REF: break;
case tensorflow::DataType::DT_BOOL_REF: break;
case tensorflow::DataType::DT_QINT8_REF: break;
case tensorflow::DataType::DT_QUINT8_REF: break;
case tensorflow::DataType::DT_QINT32_REF: break;
case tensorflow::DataType::DT_BFLOAT16_REF: break;
case tensorflow::DataType::DT_QINT16_REF: break;
case tensorflow::DataType::DT_QUINT16_REF: break;
case tensorflow::DataType::DT_UINT16_REF: break;
case tensorflow::DataType::DT_COMPLEX128_REF: break;
case tensorflow::DataType::DT_HALF_REF: break;
case tensorflow::DataType::DT_RESOURCE_REF: break;
case tensorflow::DataType::DT_VARIANT_REF: break;
case tensorflow::DataType::DT_UINT32_REF: break;
case tensorflow::DataType::DT_UINT64_REF: break;
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: break;
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
}
static literal parse_tensor(const tensorflow::TensorProto& t)
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
if(dims.empty())
{
dims = {1};
}
size_t shape_size = std::accumulate(dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)};
case tensorflow::DataType::DT_HALF:
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half};
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const size_t& shape_size)
{
std::vector<T> data_vals(shape_size);
// a single data value stands for the whole shape and is broadcast to fill it
if(data.size() == 1)
{
std::fill(data_vals.begin(), data_vals.end(), data[0]);
}
else
copy(data.begin(), data.end(), data_vals.begin());
return data_vals;
}
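// Example for get_data_vals above: a TensorProto of shape {2, 2} whose
// float_val holds the single element 3.0f yields {3.0f, 3.0f, 3.0f, 3.0f}
// here, matching the version-0 fill convention in tensor.proto.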
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
{
std::vector<size_t> dims;
auto input_dims = s.dim();
std::transform(input_dims.begin(),
input_dims.end(),
std::back_inserter(dims),
[](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); });
return dims;
}
};
program parse_tf(const std::string& name, bool is_nhwc)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf_parser parser;
parser.is_nhwc = is_nhwc;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
{
parser.parse_from(input);
}
catch(...)
{
std::cerr << parser.prog << std::endl;
throw;
}
#else
parser.parse_from(input);
#endif
return std::move(parser.prog);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;
// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;
// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
#include <migraphx/tf.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>
template <class T>
auto get_hash(const T& x)
{
return std::hash<T>{}(x);
}
template <class F>
migraphx::argument run_cpu(F f)
{
auto p = f();
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
}
auto out = p.eval(m);
std::cout << p << std::endl;
return out;
}
template <class F>
migraphx::argument run_gpu(F f)
{
auto p = f();
p.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
}
auto out = migraphx::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl;
return out;
}
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
{
auto x = run_cpu(f);
auto y = run_gpu(f);
migraphx::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const migraphx::program& prog, double tolerance = 80)
{
for(auto&& ins : prog)
{
if(ins.name().front() == '@')
continue;
if(ins.name() == "broadcast")
continue;
if(ins.name() == "transpose")
continue;
if(ins.name() == "reshape")
continue;
auto create_program = [&] {
migraphx::program p;
std::vector<migraphx::instruction_ref> inputs;
for(auto&& arg : ins.inputs())
{
if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal()));
else
inputs.push_back(
p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.get_operator(), inputs);
return p;
};
try
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << create_program() << std::endl;
verify_program(ins.name(), create_program, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
throw;
}
}
}
template <class F>
void verify_reduced(F f, int n, double tolerance = 80)
{
auto create_program = [&] {
migraphx::program p = f();
auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end());
return p;
};
std::cout << "Verify: " << std::endl;
std::cout << create_program() << std::endl;
verify_program(std::to_string(n), create_program, tolerance);
}
template <class F>
void verify_reduced_program(F f, double tolerance = 80)
{
migraphx::program p = f();
auto n = std::distance(p.begin(), p.end());
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(f, i, tolerance);
}
}
int main(int argc, char const* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
if(not args.empty())
{
bool is_nhwc = true;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "nchw"; }))
{
is_nhwc = false;
}
std::string file = args.front();
auto p = migraphx::parse_tf(file, is_nhwc);
std::cout << p << std::endl;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
{
verify_instructions(p);
}
else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
{
verify_reduced_program([&] { return migraphx::parse_tf(file, is_nhwc); });
}
else
{
verify_program(file, [&] { return migraphx::parse_tf(file, is_nhwc); });
}
}
}
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;
// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;
// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};
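The acceptance rule spelled out in the comment translates directly into code. A hedged sketch (the may_consume helper and its consumer-side parameters are illustrative; the header name follows the protoc convention):

#include <versions.pb.h> // generated from versions.proto (name assumed)

// Returns true when a consumer at `consumer_version`, accepting producers at
// or above `min_producer_version`, may consume data tagged with `data`.
bool may_consume(const tensorflow::VersionDef& data,
                 int consumer_version,
                 int min_producer_version)
{
    if(data.producer() < min_producer_version)
        return false;
    if(consumer_version < data.min_consumer())
        return false;
    for(int bad : data.bad_consumers())
        if(bad == consumer_version)
            return false;
    return true;
}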
...
@@ -126,6 +126,15 @@ foreach(ONNX_TEST ${ONNX_TESTS})
add_dependencies(check ${TEST_NAME})
endforeach()
# tf test
add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_dependencies(tests test_tf)
add_dependencies(check test_tf)
if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py)
endif()
...
...
@@ -925,6 +925,193 @@ void gemm_test()
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
template <class T>
void gemm_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test_ex<float>)
TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(gemm_multi_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_multi_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_multi_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(maxpool_test)
{
migraphx::program p;
...
@@ -1036,6 +1223,176 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-2.71138556, -5.85030702, -3.74063578, -4.22915517, -6.15821977, -5.96072346, -3.57208097,
-5.78313166, -5.51435497, -3.67224195, -3.88393048, -2.57061599, -5.54431083, -6.27880025,
-5.1878749, -6.1318955, -5.29178545, -4.22537886, -3.75693516, -7.07047099, -4.45763333,
-4.66281846, -6.18290503, -4.11886536, -6.17408292, -4.18030052, -4.64570814, -4.64354473,
-3.06629525, -3.80807681, -4.69162374, -5.53605222, -3.20969275, -4.82645674, -6.63942356,
-4.73634471, -3.86003866, -5.32738981, -4.22249802, -4.51258693, -2.41455206, -3.48343199,
-5.86215889, -4.93435935, -4.83713408, -2.97471885, -2.16666459, -3.69133151, -4.71640968,
-5.64652924, -3.60709827, -5.87967748, -3.8809403, -4.33917815};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 0;
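    // logsoftmax normalizes over the dimensions from `axis` onward; with
    // axis = 0 the whole tensor forms a single normalization group, so each
    // output is a[i] minus log-sum-exp over all 54 elements (consistent with
    // the expected values above).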
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_1)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-1.77931988, -4.91824134, -2.80857010, -3.29708949, -5.22615409, -5.02865778, -2.64001529,
-4.85106598, -4.58228929, -2.74017627, -2.95186480, -1.63855031, -4.61224515, -5.34673457,
-4.25580922, -5.19982982, -4.35971977, -3.29331318, -2.82486948, -6.13840531, -3.52556765,
-3.73075278, -5.25083935, -3.18679968, -5.24201724, -3.24823484, -3.71364246, -4.14309917,
-2.56584969, -3.30763125, -4.19117818, -5.03560666, -2.70924719, -4.32601118, -6.13897800,
-4.23589915, -3.35959310, -4.82694425, -3.72205246, -4.01214137, -1.91410650, -2.98298643,
-5.36171333, -4.43391379, -4.33668852, -2.47427329, -1.66621903, -3.19088595, -4.21596412,
-5.14608368, -3.10665271, -5.37923192, -3.38049474, -3.83873259};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_2)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.79763715, -3.93655861, -1.82688737, -2.31540676, -4.24447136, -4.04697505, -1.65833256,
-3.86938325, -3.60060656, -1.81223672, -2.02392525, -0.71061076, -3.68430560, -4.41879502,
-3.32786967, -4.27189027, -3.43178022, -2.36537363, -1.35498658, -4.66852241, -2.05568475,
-2.26086988, -3.78095645, -1.71691678, -3.77213434, -1.77835194, -2.24375956, -2.74631770,
-1.16906822, -1.91084978, -2.79439671, -3.63882519, -1.31246572, -2.92922971, -4.74219653,
-2.83911768, -2.19738500, -3.66473615, -2.55984436, -2.84993327, -0.75189840, -1.82077833,
-4.19950523, -3.27170569, -3.17448042, -1.65286841, -0.84481415, -2.36948107, -3.39455924,
-4.32467880, -2.28524783, -4.55782704, -2.55908986, -3.01732771};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 2;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_3)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.33690375, -3.47582521, -1.36615397, -0.27936556, -2.20843016, -2.01093385, -0.22551114,
-2.43656183, -2.16778514, -1.57241522, -1.78410375, -0.47078926, -1.06745881, -1.80194823,
-0.71102288, -2.30719726, -1.46708721, -0.40068062, -0.42698261, -3.74051844, -1.12768078,
-1.07891856, -2.59900513, -0.53496546, -2.56139951, -0.56761711, -1.03302473, -2.09771276,
-0.52046328, -1.26224484, -1.76322959, -2.60765807, -0.28129860, -0.81424303, -2.62720985,
-0.72413100, -0.65570381, -2.12305496, -1.01816317, -2.48063402, -0.38259915, -1.45147908,
-1.84310238, -0.91530284, -0.81807757, -1.31692881, -0.50887455, -2.03354147, -1.48767160,
-2.41779116, -0.37836019, -2.56853147, -0.56979429, -1.02803214};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 3;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_4)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
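    // With axis == 4 (the rank of the shape) every element is its own
    // normalization group, so logsoftmax reduces to log(exp(x)/exp(x)) = 0
    // for every entry, matching the all-zero expected output below.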
std::vector<float> s = {0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 4;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(conv2d_test)
{
    migraphx::program p;
......
...@@ -16,7 +16,7 @@
#include <future>
#include <thread>
#include <test.hpp>
#ifdef __clang__
#pragma clang diagnostic push
...@@ -136,7 +136,7 @@ migraphx::argument run_gpu(migraphx::program& p)
}
template <class V>
void run_verify_program()
{
    auto_print::set_terminate_handler(migraphx::get_type_name<V>());
    // std::cout << migraphx::get_type_name<V>() << std::endl;
...@@ -158,7 +158,27 @@ void verify_program()
    std::set_terminate(nullptr);
}
template <class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template <class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
template <class T>
int verify_program<T>::static_register = auto_register_verify_program<T>(); // NOLINT
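// Each test struct below registers itself by inheriting from
// verify_program<Derived> (CRTP): instantiating the derived type instantiates
// static_register, whose initializer calls auto_register_verify_program<T>()
// and adds the test case before main() runs.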
struct test_literals : verify_program<test_literals>
{
    migraphx::program create_program() const
    {
...@@ -173,7 +193,7 @@ struct test_literals
    }
};
struct test_add : verify_program<test_add>
{
    migraphx::program create_program() const
    {
...@@ -186,7 +206,7 @@ struct test_add
    }
};
struct test_add_half : verify_program<test_add_half>
{
    migraphx::program create_program() const
    {
...@@ -199,7 +219,7 @@ struct test_add_half
    }
};
struct test_mul : verify_program<test_mul>
{
    migraphx::program create_program() const
    {
...@@ -212,7 +232,7 @@ struct test_mul
    }
};
struct test_exp : verify_program<test_exp>
{
    migraphx::program create_program() const
    {
...@@ -225,7 +245,7 @@ struct test_exp
    }
};
struct test_log : verify_program<test_log>
{
    migraphx::program create_program() const
    {
...@@ -238,7 +258,7 @@ struct test_log
    }
};
struct test_sin : verify_program<test_sin>
{
    migraphx::program create_program() const
    {
...@@ -250,7 +270,7 @@ struct test_sin
    }
};
struct test_cos : verify_program<test_cos>
{
    migraphx::program create_program() const
    {
...@@ -262,7 +282,7 @@ struct test_cos
    }
};
struct test_tan : verify_program<test_tan>
{
    migraphx::program create_program() const
    {
...@@ -274,7 +294,7 @@ struct test_tan
    }
};
struct test_sinh : verify_program<test_sinh>
{
    migraphx::program create_program() const
    {
...@@ -286,7 +306,7 @@ struct test_sinh
    }
};
struct test_cosh : verify_program<test_cosh>
{
    migraphx::program create_program() const
    {
...@@ -298,7 +318,7 @@ struct test_cosh
    }
};
struct test_tanh : verify_program<test_tanh>
{
    migraphx::program create_program() const
    {
...@@ -309,7 +329,7 @@ struct test_tanh
    }
};
struct test_asin : verify_program<test_asin>
{
    migraphx::program create_program() const
    {
...@@ -321,7 +341,7 @@ struct test_asin
    }
};
struct test_acos : verify_program<test_acos>
{
    migraphx::program create_program() const
    {
...@@ -333,7 +353,7 @@ struct test_acos
    }
};
struct test_atan : verify_program<test_atan>
{
    migraphx::program create_program() const
    {
...@@ -345,7 +365,7 @@ struct test_atan
    }
};
struct test_scale : verify_program<test_scale>
{
    migraphx::program create_program() const
    {
...@@ -359,7 +379,7 @@ struct test_scale
    }
};
struct test_slice : verify_program<test_slice>
{
    migraphx::program create_program() const
    {
...@@ -374,7 +394,7 @@ struct test_slice
    }
};
struct test_triadd : verify_program<test_triadd>
{
    migraphx::program create_program() const
    {
...@@ -389,7 +409,7 @@ struct test_triadd
    }
};
struct test_triadd2 : verify_program<test_triadd2>
{
    migraphx::program create_program() const
    {
...@@ -406,7 +426,7 @@ struct test_triadd2
    }
};
struct test_add_broadcast : verify_program<test_add_broadcast>
{
    migraphx::program create_program() const
    {
...@@ -420,7 +440,7 @@ struct test_add_broadcast
    }
};
struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{
    migraphx::program create_program() const
    {
...@@ -434,7 +454,7 @@ struct test_add_broadcast2
    }
};
struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{
    migraphx::program create_program() const
    {
...@@ -448,7 +468,7 @@ struct test_add_broadcast3
    }
};
struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{
    migraphx::program create_program() const
    {
...@@ -462,7 +482,7 @@ struct test_add_broadcast4
    }
};
struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{
    migraphx::program create_program() const
    {
...@@ -476,7 +496,7 @@ struct test_add_broadcast5
    }
};
struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{
    migraphx::program create_program() const
    {
...@@ -492,7 +512,7 @@ struct test_triadd_broadcast
    }
};
struct test_sub : verify_program<test_sub>
{
    migraphx::program create_program() const
    {
...@@ -507,7 +527,7 @@ struct test_sub
    }
};
struct test_sub2 : verify_program<test_sub2>
{
    migraphx::program create_program() const
    {
...@@ -524,7 +544,7 @@ struct test_sub2
    }
};
struct test_softmax : verify_program<test_softmax>
{
    migraphx::program create_program() const
    {
...@@ -535,7 +555,7 @@ struct test_softmax
    }
};
struct test_softmax2 : verify_program<test_softmax2>
{
    migraphx::program create_program() const
    {
...@@ -547,7 +567,7 @@ struct test_softmax2
    }
};
struct test_conv : verify_program<test_conv>
{
    migraphx::program create_program() const
    {
...@@ -561,7 +581,7 @@ struct test_conv
    }
};
struct test_conv2 : verify_program<test_conv2>
{
    migraphx::program create_program() const
    {
...@@ -575,7 +595,7 @@ struct test_conv2
    }
};
struct test_group_conv : verify_program<test_group_conv>
{
    migraphx::program create_program() const
    {
...@@ -591,7 +611,7 @@ struct test_group_conv
    }
};
struct test_conv_relu : verify_program<test_conv_relu>
{
    migraphx::program create_program() const
    {
...@@ -606,7 +626,7 @@ struct test_conv_relu
    }
};
struct test_conv_relu_half : verify_program<test_conv_relu_half>
{
    migraphx::program create_program() const
    {
...@@ -621,7 +641,7 @@ struct test_conv_relu_half
    }
};
struct test_add_relu : verify_program<test_add_relu>
{
    migraphx::program create_program() const
    {
...@@ -634,7 +654,7 @@ struct test_add_relu
    }
};
struct test_sigmoid : verify_program<test_sigmoid>
{
    migraphx::program create_program() const
    {
...@@ -645,7 +665,7 @@ struct test_sigmoid
    }
};
struct test_abs : verify_program<test_abs>
{
    migraphx::program create_program() const
    {
...@@ -656,7 +676,7 @@ struct test_abs
    }
};
struct test_leaky_relu : verify_program<test_leaky_relu>
{
    migraphx::program create_program() const
    {
...@@ -667,7 +687,7 @@ struct test_leaky_relu
    }
};
struct test_elu : verify_program<test_elu>
{
    migraphx::program create_program() const
    {
...@@ -678,7 +698,7 @@ struct test_elu
    }
};
struct test_relu_lrn : verify_program<test_relu_lrn>
{
    migraphx::program create_program() const
    {
...@@ -690,7 +710,7 @@ struct test_relu_lrn
    }
};
struct test_conv_pooling : verify_program<test_conv_pooling>
{
    migraphx::program create_program() const
    {
...@@ -706,7 +726,7 @@ struct test_conv_pooling
    }
};
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{
    migraphx::program create_program() const
    {
...@@ -721,7 +741,7 @@ struct test_global_avg_pooling
    }
};
struct test_global_max_pooling : verify_program<test_global_max_pooling>
{
    migraphx::program create_program() const
    {
...@@ -736,7 +756,7 @@ struct test_global_max_pooling
    }
};
struct test_gemm : verify_program<test_gemm>
{
    migraphx::program create_program() const
    {
...@@ -748,7 +768,19 @@ struct test_gemm
    }
};
struct test_gemm_ex : verify_program<test_gemm_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
p.add_instruction(migraphx::op::dot{}, a, b);
return p;
}
};
struct test_gemm_half : verify_program<test_gemm_half>
{
    migraphx::program create_program() const
    {
...@@ -760,7 +792,7 @@ struct test_gemm_half
    }
};
struct test_gemm_ld //: verify_program<test_gemm_ld>
{
    migraphx::program create_program() const
    {
...@@ -774,7 +806,7 @@ struct test_gemm_ld
    }
};
struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{
    migraphx::program create_program() const
    {
...@@ -787,7 +819,20 @@ struct test_gemm_transposeb
    }
};
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, b);
p.add_instruction(migraphx::op::dot{}, a, bt);
return p;
}
};
struct test_gemm_transposea : verify_program<test_gemm_transposea>
{
    migraphx::program create_program() const
    {
...@@ -800,7 +845,20 @@ struct test_gemm_transposea
    }
};
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
auto at = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
p.add_instruction(migraphx::op::dot{}, at, b);
return p;
}
};
struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{
    migraphx::program create_program() const
    {
...@@ -814,7 +872,39 @@ struct test_gemm_transposeab
    }
};
struct gemm_mutli_dim_2
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_mutli_dim_2_3
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
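// Note: unlike the test_* structs, gemm_mutli_dim_2 and gemm_mutli_dim_2_3 do
// not derive from verify_program, so they are defined here but never
// auto-registered with the test runner.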
struct test_contiguous : verify_program<test_contiguous>
{
    migraphx::program create_program() const
    {
...@@ -827,7 +917,7 @@ struct test_contiguous
    }
};
struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{
    migraphx::program create_program() const
    {
...@@ -843,7 +933,7 @@ struct test_eliminate_contiguous
    }
};
struct test_transpose : verify_program<test_transpose>
{
    migraphx::program create_program() const
    {
...@@ -857,7 +947,7 @@ struct test_transpose
    }
};
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{
    const size_t width = 14;
    const size_t height = 14;
...@@ -880,7 +970,7 @@ struct test_batchnorm_inference_2
    }
};
struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{
    const size_t width = 3;
    const size_t height = 3;
...@@ -903,7 +993,7 @@ struct test_batchnorm_inference
    }
};
struct test_conv_bn : verify_program<test_conv_bn>
{
    migraphx::program create_program() const
    {
...@@ -924,7 +1014,7 @@ struct test_conv_bn
    }
};
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{
    migraphx::program create_program() const
    {
...@@ -948,7 +1038,7 @@ struct test_conv_bn_relu_pooling
    }
};
struct test_concat : verify_program<test_concat>
{
    migraphx::program create_program() const
    {
...@@ -965,7 +1055,7 @@ struct test_concat
    }
};
struct test_concat2 : verify_program<test_concat2>
{
    migraphx::program create_program() const
    {
...@@ -982,7 +1072,7 @@ struct test_concat2
    }
};
struct test_concat_relu : verify_program<test_concat_relu>
{
    migraphx::program create_program() const
    {
...@@ -1003,7 +1093,7 @@ struct test_concat_relu
    }
};
struct test_pad : verify_program<test_pad>
{
    migraphx::program create_program() const
    {
...@@ -1022,7 +1112,7 @@ struct test_pad
    }
};
struct test_pooling_autopad : verify_program<test_pooling_autopad>
{
    migraphx::program create_program() const
    {
...@@ -1038,7 +1128,7 @@ struct test_pooling_autopad
    }
};
struct test_gather : verify_program<test_gather>
{
    migraphx::program create_program() const
    {
...@@ -1054,7 +1144,7 @@ struct test_gather
    }
};
struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
{
    migraphx::program create_program() const
    {
...@@ -1070,7 +1160,7 @@ struct test_gather_neg_axis
    }
};
struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{
    migraphx::program create_program() const
    {
...@@ -1086,7 +1176,7 @@ struct test_gather_scalar_output
    }
};
struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
{
    migraphx::program create_program() const
    {
...@@ -1102,7 +1192,7 @@ struct test_gather_scalar_index
    }
};
struct test_gather_1d_index : verify_program<test_gather_1d_index>
{
    migraphx::program create_program() const
    {
...@@ -1164,7 +1254,7 @@ void manual_test_concat_relu()
    std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{
    static migraphx::instruction_ref
    add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
...@@ -1201,7 +1291,7 @@ struct test_conv_bn_relu_pooling2
    }
};
struct test_rnn_forward : verify_program<test_rnn_forward>
{
    migraphx::program create_program() const
    {
...@@ -1243,7 +1333,7 @@ struct test_rnn_forward
    }
};
struct test_rnn_forward10 : verify_program<test_rnn_forward10>
{
    migraphx::program create_program() const
    {
...@@ -1285,7 +1375,7 @@ struct test_rnn_forward10
    }
};
struct test_rnn_reverse : verify_program<test_rnn_reverse>
{
    migraphx::program create_program() const
    {
...@@ -1325,7 +1415,7 @@ struct test_rnn_reverse
    }
};
struct test_rnn_reverse2 : verify_program<test_rnn_reverse2>
{
    migraphx::program create_program() const
    {
...@@ -1365,7 +1455,7 @@ struct test_rnn_reverse2
    }
};
struct test_rnn_3args : verify_program<test_rnn_3args>
{
    migraphx::program create_program() const
    {
...@@ -1397,7 +1487,7 @@ struct test_rnn_3args
    }
};
struct test_rnn_4args : verify_program<test_rnn_4args>
{
    migraphx::program create_program() const
    {
...@@ -1432,7 +1522,7 @@ struct test_rnn_4args
    }
};
struct test_rnn_5args : verify_program<test_rnn_5args>
{
    migraphx::program create_program() const
    {
...@@ -1471,7 +1561,7 @@ struct test_rnn_5args
    }
};
struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
{
    migraphx::program create_program() const
    {
...@@ -1513,7 +1603,7 @@ struct test_rnn_bidirectional
    }
};
struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
{
    migraphx::program create_program() const
    {
...@@ -1554,7 +1644,7 @@ struct test_rnn_bidirectional10
    }
};
struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
{
    migraphx::program create_program() const
    {
...@@ -1589,7 +1679,7 @@ struct test_rnn_bi_3args
    }
};
struct test_gru_forward_last : verify_program<test_gru_forward_last>
{
    migraphx::program create_program() const
    {
...@@ -1633,7 +1723,7 @@ struct test_gru_forward_last
    }
};
struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
{
    migraphx::program create_program() const
    {
...@@ -1675,7 +1765,7 @@ struct test_gru_forward_hs
    }
};
struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
{
    migraphx::program create_program() const
    {
...@@ -1711,7 +1801,7 @@ struct test_gru_forward_3args_und
    }
};
struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
{
    migraphx::program create_program() const
    {
...@@ -1743,7 +1833,7 @@ struct test_gru_forward_3args
    }
};
struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
{
    migraphx::program create_program() const
    {
...@@ -1775,7 +1865,7 @@ struct test_gru_forward_seq1
    }
};
struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_actv>
{
    migraphx::program create_program() const
    {
...@@ -1805,7 +1895,7 @@ struct test_gru_forward_default_actv
    }
};
struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{
    migraphx::program create_program() const
    {
...@@ -1846,7 +1936,7 @@ struct test_gru_forward_default_actv1
    }
};
struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{
    migraphx::program create_program() const
    {
...@@ -1890,7 +1980,7 @@ struct test_gru_reverse_last
    }
};
struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
{
    migraphx::program create_program() const
    {
...@@ -1922,7 +2012,7 @@ struct test_gru_reverse_3args
    }
};
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
{
    migraphx::program create_program() const
    {
...@@ -1966,7 +2056,7 @@ struct test_gru_bidirct_last
    }
};
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
{
    migraphx::program create_program() const
    {
...@@ -2008,7 +2098,7 @@ struct test_gru_bidirct_hs
    }
};
struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
{
    migraphx::program create_program() const
    {
...@@ -2044,7 +2134,7 @@ struct test_gru_bidirct_3args_und
    }
};
struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
{
    migraphx::program create_program() const
    {
...@@ -2076,7 +2166,7 @@ struct test_gru_bidirct_3args
    }
};
struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
{
    migraphx::program create_program() const
    {
...@@ -2108,7 +2198,7 @@ struct test_gru_bidirct_seq1
    }
};
struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_actv>
{
    migraphx::program create_program() const
    {
...@@ -2138,7 +2228,7 @@ struct test_gru_bidirct_default_actv
    }
};
struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_actv1>
{
    migraphx::program create_program() const
    {
...@@ -2180,7 +2270,7 @@ struct test_gru_bidirct_default_actv1
    }
};
struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{
    migraphx::program create_program() const
    {
...@@ -2230,7 +2320,7 @@ struct test_lstm_forward_last
    }
};
struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{
    migraphx::program create_program() const
    {
...@@ -2280,7 +2370,7 @@ struct test_lstm_forward_hs
    }
};
struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
{
    migraphx::program create_program() const
    {
...@@ -2320,7 +2410,7 @@ struct test_lstm_forward_3args_und
    }
};
struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
{
    migraphx::program create_program() const
    {
...@@ -2354,7 +2444,7 @@ struct test_lstm_forward_3args
    }
};
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{
    migraphx::program create_program() const
    {
...@@ -2388,7 +2478,7 @@ struct test_lstm_forward_seq1
    }
};
struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
{
    migraphx::program create_program() const
    {
...@@ -2418,7 +2508,7 @@ struct test_lstm_forward_default_actv
    }
};
struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
{
    migraphx::program create_program() const
    {
...@@ -2459,7 +2549,7 @@ struct test_lstm_forward_default_actv1
    }
};
struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
{
    migraphx::program create_program() const
    {
...@@ -2510,7 +2600,7 @@ struct test_lstm_reverse_last
    }
};
struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
{
    migraphx::program create_program() const
    {
...@@ -2544,7 +2634,7 @@ struct test_lstm_reverse_3args
    }
};
struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
{
    migraphx::program create_program() const
    {
...@@ -2579,7 +2669,7 @@ struct test_lstm_reverse_3args_cell_output
    }
};
struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
{
    migraphx::program create_program() const
    {
...@@ -2630,7 +2720,7 @@ struct test_lstm_bidirct_last
    }
};
struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{
    migraphx::program create_program() const
    {
...@@ -2672,7 +2762,7 @@ struct test_lstm_bidirct_hs
    }
};
struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
{
    migraphx::program create_program() const
    {
...@@ -2711,7 +2801,7 @@ struct test_lstm_bidirct_3args_und
    }
};
struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
{
    migraphx::program create_program() const
    {
...@@ -2743,7 +2833,7 @@ struct test_lstm_bidirct_3args
    }
};
struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
{
    migraphx::program create_program() const
    {
...@@ -2775,7 +2865,7 @@ struct test_lstm_bidirct_seq1
    }
};
struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
{
    migraphx::program create_program() const
    {
...@@ -2805,7 +2895,7 @@ struct test_lstm_bidirct_default_actv
    }
};
struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
{
    migraphx::program create_program() const
    {
...@@ -2847,7 +2937,7 @@ struct test_lstm_bidirct_default_actv1
    }
};
struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
{
    migraphx::program create_program() const
    {
...@@ -2889,116 +2979,41 @@ struct test_lstm_bidirct_default_actv2
    }
};
template <int Axis>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
        auto param = p.add_parameter("0", s);
        p.add_instruction(migraphx::op::logsoftmax{Axis}, param);

        return p;
    }
};

template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template struct test_logsoftmax<4>;

template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape s{migraphx::shape::float_type, {3}};
        auto param = p.add_parameter("0", s);
        p.add_instruction(migraphx::op::logsoftmax{Axis}, param);

        return p;
    }
};

template struct test_logsoftmax_1<0>;
template struct test_logsoftmax_1<1>;

int main(int argc, const char* argv[]) { test::run(argc, argv); }
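// Explicitly instantiating test_logsoftmax<Axis> and test_logsoftmax_1<Axis>
// also instantiates their verify_program bases, so each axis variant is
// auto-registered; test::run() then executes every registered case, replacing
// the old hand-maintained list of verify_program<...>() calls in main().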
...@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
    get_test_cases().emplace_back(std::move(name), std::move(f));
}
struct auto_register_test_case
{
    template <class F>
    auto_register_test_case(const char* name, F f) noexcept
    {
        add_test_case(name, f);
    }
...@@ -258,9 +258,9 @@ inline void run(int argc, const char* argv[])
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...)                                                    \
    static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
        test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
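// TEST_CASE presumably both declares the test function and registers it via
// TEST_CASE_REGISTER above (its expansion is elided in this hunk).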
// NOLINTNEXTLINE
#define TEST_CASE(...) \
......
[binary ONNX fixtures: constant_scalar.onnx now carries the graph name
"test-constant" for its single Constant node emitting "const_tensor", and a new
logsoftmax_test.onnx holds a graph "test_logsoftmax" with one LogSoftmax node
(axis attribute set), input "x" and output "y"; the raw protobuf bytes are not
reproduced here.]
...@@ -470,8 +470,8 @@ TEST_CASE(flatten_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
    p.add_instruction(migraphx::op::flatten{2}, l0);
    p.add_instruction(migraphx::op::flatten{1}, l0);
    auto prog = migraphx::parse_onnx("flatten_test.onnx");
    EXPECT(p == prog);
...@@ -524,7 +524,7 @@ TEST_CASE(constant_test)
TEST_CASE(constant_test_scalar)
{
    migraphx::program p;
    p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
    auto prog = migraphx::parse_onnx("constant_scalar.onnx");
    EXPECT(p == prog);
...@@ -572,6 +572,27 @@ TEST_CASE(gemm_test)
    EXPECT(p == prog);
}
TEST_CASE(gemm_ex)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1);
auto beta = 0.8f;
auto l_beta = p.add_literal(beta);
auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
p.add_instruction(migraphx::op::add{}, res_ab, res_c);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
EXPECT(p == prog);
}
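// The hand-built program above spells out the expected lowering of ONNX Gemm:
// transA becomes an op::transpose, alpha rides along inside op::dot, and beta
// is applied to C through a broadcast scalar multiply before the final add.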
TEST_CASE(add_scalar_test)
{
    migraphx::program p;
...@@ -651,4 +672,15 @@ TEST_CASE(add_fp16_test)
    EXPECT(p == prog);
}
TEST_CASE(logsoftmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
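// (axis = 1 matches both the attribute written into logsoftmax_test.onnx and
// the ONNX default for LogSoftmax.)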
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......
[binary ONNX fixtures: the Sum example graph is renamed from "test-dropout" to
"test-sum" with its output relabeled, and the Unknown-op example graph
"test-unknown" now feeds two inputs to its second Unknown node; the raw
protobuf bytes are not reproduced here.]