Commit 1c3b16d2 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents 015d1ac4 3d200e1c
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";
// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
TensorShapeProto tensor_shape = 2;
// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.
// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;
// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.
// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];
// DT_FLOAT.
repeated float float_val = 5 [packed = true];
// DT_DOUBLE.
repeated double double_val = 6 [packed = true];
// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];
// DT_STRING
repeated bytes string_val = 8;
// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];
// DT_INT64
repeated int64 int64_val = 10 [packed = true];
// DT_BOOL
repeated bool bool_val = 11 [packed = true];
// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];
// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;
// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;
// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];
// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};
// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}
// Protocol buffer representing the shape of tensors.
syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
package tensorflow;
// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;
// Optional name of the tensor dimension.
string name = 2;
};
// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;
// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};
This diff is collapsed.
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;
// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;
// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
#include <migraphx/tf.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
template <class T>
auto get_hash(const T& x)
{
return std::hash<T>{}(x);
}
template <class F>
migraphx::argument run_cpu(F f)
{
auto p = f();
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
}
auto out = p.eval(m);
std::cout << p << std::endl;
return out;
}
template <class F>
migraphx::argument run_gpu(F f)
{
auto p = f();
p.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
}
auto out = migraphx::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl;
return migraphx::gpu::from_gpu(out);
}
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
{
auto x = run_cpu(f);
auto y = run_gpu(f);
migraphx::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const migraphx::program& prog, double tolerance = 80)
{
for(auto&& ins : prog)
{
if(ins.name().front() == '@')
continue;
if(ins.name() == "broadcast")
continue;
if(ins.name() == "transpose")
continue;
if(ins.name() == "reshape")
continue;
auto create_program = [&] {
migraphx::program p;
std::vector<migraphx::instruction_ref> inputs;
for(auto&& arg : ins.inputs())
{
if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal()));
else
inputs.push_back(
p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.get_operator(), inputs);
return p;
};
try
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << create_program() << std::endl;
verify_program(ins.name(), create_program, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
throw;
}
}
}
template <class F>
void verify_reduced(F f, int n, double tolerance = 80)
{
auto create_program = [&] {
migraphx::program p = f();
auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end());
return p;
};
std::cout << "Verify: " << std::endl;
std::cout << create_program() << std::endl;
verify_program(std::to_string(n), create_program, tolerance);
}
template <class F>
void verify_reduced_program(F f, double tolerance = 80)
{
migraphx::program p = f();
auto n = std::distance(p.begin(), p.end());
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(f, i, tolerance);
}
}
int main(int argc, char const* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
if(not args.empty())
{
bool is_nhwc = true;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "nchw"; }))
{
is_nhwc = false;
}
std::string file = args.front();
auto p = migraphx::parse_tf(file, is_nhwc);
std::cout << p << std::endl;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
{
verify_instructions(p);
}
else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
{
verify_reduced_program([&] { return migraphx::parse_tf(file, is_nhwc); });
}
else
{
verify_program(file, [&] { return migraphx::parse_tf(file, is_nhwc); });
}
}
}
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;
// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;
// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};
......@@ -126,6 +126,15 @@ foreach(ONNX_TEST ${ONNX_TESTS})
add_dependencies(check ${TEST_NAME})
endforeach()
# tf test
add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_dependencies(tests test_tf)
add_dependencies(check test_tf)
if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py)
endif()
......
......@@ -1269,6 +1269,176 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-2.71138556, -5.85030702, -3.74063578, -4.22915517, -6.15821977, -5.96072346, -3.57208097,
-5.78313166, -5.51435497, -3.67224195, -3.88393048, -2.57061599, -5.54431083, -6.27880025,
-5.1878749, -6.1318955, -5.29178545, -4.22537886, -3.75693516, -7.07047099, -4.45763333,
-4.66281846, -6.18290503, -4.11886536, -6.17408292, -4.18030052, -4.64570814, -4.64354473,
-3.06629525, -3.80807681, -4.69162374, -5.53605222, -3.20969275, -4.82645674, -6.63942356,
-4.73634471, -3.86003866, -5.32738981, -4.22249802, -4.51258693, -2.41455206, -3.48343199,
-5.86215889, -4.93435935, -4.83713408, -2.97471885, -2.16666459, -3.69133151, -4.71640968,
-5.64652924, -3.60709827, -5.87967748, -3.8809403, -4.33917815};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 0;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_1)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-1.77931988, -4.91824134, -2.80857010, -3.29708949, -5.22615409, -5.02865778, -2.64001529,
-4.85106598, -4.58228929, -2.74017627, -2.95186480, -1.63855031, -4.61224515, -5.34673457,
-4.25580922, -5.19982982, -4.35971977, -3.29331318, -2.82486948, -6.13840531, -3.52556765,
-3.73075278, -5.25083935, -3.18679968, -5.24201724, -3.24823484, -3.71364246, -4.14309917,
-2.56584969, -3.30763125, -4.19117818, -5.03560666, -2.70924719, -4.32601118, -6.13897800,
-4.23589915, -3.35959310, -4.82694425, -3.72205246, -4.01214137, -1.91410650, -2.98298643,
-5.36171333, -4.43391379, -4.33668852, -2.47427329, -1.66621903, -3.19088595, -4.21596412,
-5.14608368, -3.10665271, -5.37923192, -3.38049474, -3.83873259};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_2)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.79763715, -3.93655861, -1.82688737, -2.31540676, -4.24447136, -4.04697505, -1.65833256,
-3.86938325, -3.60060656, -1.81223672, -2.02392525, -0.71061076, -3.68430560, -4.41879502,
-3.32786967, -4.27189027, -3.43178022, -2.36537363, -1.35498658, -4.66852241, -2.05568475,
-2.26086988, -3.78095645, -1.71691678, -3.77213434, -1.77835194, -2.24375956, -2.74631770,
-1.16906822, -1.91084978, -2.79439671, -3.63882519, -1.31246572, -2.92922971, -4.74219653,
-2.83911768, -2.19738500, -3.66473615, -2.55984436, -2.84993327, -0.75189840, -1.82077833,
-4.19950523, -3.27170569, -3.17448042, -1.65286841, -0.84481415, -2.36948107, -3.39455924,
-4.32467880, -2.28524783, -4.55782704, -2.55908986, -3.01732771};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 2;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_3)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.33690375, -3.47582521, -1.36615397, -0.27936556, -2.20843016, -2.01093385, -0.22551114,
-2.43656183, -2.16778514, -1.57241522, -1.78410375, -0.47078926, -1.06745881, -1.80194823,
-0.71102288, -2.30719726, -1.46708721, -0.40068062, -0.42698261, -3.74051844, -1.12768078,
-1.07891856, -2.59900513, -0.53496546, -2.56139951, -0.56761711, -1.03302473, -2.09771276,
-0.52046328, -1.26224484, -1.76322959, -2.60765807, -0.28129860, -0.81424303, -2.62720985,
-0.72413100, -0.65570381, -2.12305496, -1.01816317, -2.48063402, -0.38259915, -1.45147908,
-1.84310238, -0.91530284, -0.81807757, -1.31692881, -0.50887455, -2.03354147, -1.48767160,
-2.41779116, -0.37836019, -2.56853147, -0.56979429, -1.02803214};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 3;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_4)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 4;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(conv2d_test)
{
migraphx::program p;
......
......@@ -2977,6 +2977,34 @@ struct test_lstm_bidirct_default_actv2
}
};
template <int Axis>
struct test_logsoftmax
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template <int Axis>
struct test_logsoftmax_1
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
int main()
{
verify_program<test_relu_lrn>();
......@@ -3095,4 +3123,11 @@ int main()
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
verify_program<test_logsoftmax<0>>();
verify_program<test_logsoftmax<1>>();
verify_program<test_logsoftmax<2>>();
verify_program<test_logsoftmax<3>>();
verify_program<test_logsoftmax<4>>();
verify_program<test_logsoftmax_1<0>>();
verify_program<test_logsoftmax_1<1>>();
}
shape-gather-example:O
2value"Constant*
value**B const_tensor constantb
z
constant-scalar-example:R
00"Constant*!
value**B const_tensor  test-constantb
0

B
\ No newline at end of file
logsoftmax-example:l

xy"
LogSoftmax*
axistest_logsoftmaxZ
x




b
y




B
\ No newline at end of file
......@@ -470,8 +470,8 @@ TEST_CASE(flatten_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{1}, l0);
p.add_instruction(migraphx::op::flatten{2}, l0);
p.add_instruction(migraphx::op::flatten{1}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx");
EXPECT(p == prog);
......@@ -524,7 +524,7 @@ TEST_CASE(constant_test)
TEST_CASE(constant_test_scalar)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {1}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog);
......@@ -666,4 +666,15 @@ TEST_CASE(add_fp16_test)
EXPECT(p == prog);
}
TEST_CASE(logsoftmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
 sum-example:e
 sum-example:a

0
1
23"Sum test-dropoutZ
23"Sumtest-sumZ
0

......@@ -15,7 +15,7 @@

b
2
3

B
\ No newline at end of file
unknown-example:
unknown-example:

0
12"Unknown
2"Unknown test-unknownZ

23"Unknown test-unknownZ
0


......@@ -14,7 +14,7 @@


b
2
3



......
......@@ -316,6 +316,61 @@ TEST_CASE(gather)
}
}
TEST_CASE(logsoftmax)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
TEST_CASE(dot)
{
{
......
2
0 Placeholder*
shape
:*
dtype0
2
1 Placeholder*
dtype0*
shape
:
add_bcast1Add01*
T0"
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

add1Add01*
T0"
\ No newline at end of file
;
0 Placeholder*
shape:*
dtype0
/
1 Placeholder*
dtype0*
shape:
:
bias_add1BiasAdd01*
T0*
data_formatNHWC"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment