"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "d7becadcf7ab50710066ae50848ebf7b64163a32"
Commit 6202ea15 authored by turneram

Add attention verify_onnx test

parent b745f416
attention_test.onnx (new binary file): serialized ONNX protobuf for graph attention_test containing a single Attention node (Attention_0) with inputs input, weights, bias, mask_index and output result, as produced by the attention_test() generator below.
# This script generates onnx files for MIGraphX onnx operator tests.
# To generate an individual onnx file, you can use the following
# command: python -c "import gen_onnx; gen_onnx.{test_name}_test()"
import numpy as np
import onnx
from onnx import helper
@@ -187,6 +188,22 @@ def atanh_test():
    return ([node], [x], [y])
@onnx_test
def attention_test():
    input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 384, 768])
    weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT, [768, 2304])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
    mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64, [2, 384])
    result = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 384, 768])

    node = helper.make_node('Attention',
                            inputs=['input', 'weights', 'bias', 'mask_index'],
                            outputs=['result'],
                            name="Attention_0")

    return ([node], [input, weights, bias, mask_index], [result])
@onnx_test
def averagepool_1d_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5])
...
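Following the pattern documented at the top of gen_onnx.py, the attention_test.onnx file consumed by the verify test below can be regenerated with the script's one-liner convention (the output location and any post-processing depend on the @onnx_test decorator, so treat this as the documented pattern rather than a verified invocation):

    python -c "import gen_onnx; gen_onnx.attention_test()"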
@@ -9,6 +9,32 @@
#include <migraphx/onnx.hpp>
#include "test.hpp"
TEST_CASE(attention_test)
{
    auto p = migraphx::parse_onnx("attention_test.onnx");
    p.compile(migraphx::ref::target{});

    migraphx::shape s_i{migraphx::shape::float_type, {2, 384, 768}};
    migraphx::shape s_w{migraphx::shape::float_type, {768, 2304}};
    migraphx::shape s_b{migraphx::shape::float_type, {2304}};
    migraphx::shape s_m{migraphx::shape::int64_type, {2, 384}};

    std::vector<float> input_v(2 * 384 * 768, 1);
    std::vector<float> weights_v(768 * 2304, 1);
    std::vector<float> bias_v(2304, 1);
    std::vector<int64_t> mask_index_v(2 * 384, 1);

    migraphx::parameter_map pp;
    pp["input"]      = migraphx::argument(s_i, input_v.data());
    pp["weights"]    = migraphx::argument(s_w, weights_v.data());
    pp["bias"]       = migraphx::argument(s_b, bias_v.data());
    pp["mask_index"] = migraphx::argument(s_m, mask_index_v.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold(2 * 384 * 768, 769);
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(averagepool_notset_test)
{
    auto p = migraphx::parse_onnx("averagepool_notset_test.onnx");
...
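For reference, the gold value of 769 in the test above follows directly from the all-ones inputs: the QKV projection gives 768 (dot product over the hidden dimension) plus 1 (bias) for every element, and because all attention scores are identical the softmax is uniform, so the weighted sum over V leaves every output element at 769. The numpy sketch below reproduces that arithmetic as a single-head simplification; the head count, score scaling, and masking details of the ONNX Attention operator are omitted because they do not change the result for constant inputs, so this is an illustrative check rather than the operator's actual semantics.

    import numpy as np

    batch, seq, hidden = 2, 384, 768                     # shapes from the test above
    x = np.ones((batch, seq, hidden), dtype=np.float32)  # 'input'
    w = np.ones((hidden, 3 * hidden), dtype=np.float32)  # 'weights'
    b = np.ones((3 * hidden,), dtype=np.float32)         # 'bias'

    qkv = x @ w + b                      # every element is 768 * 1 + 1 = 769
    q, k, v = np.split(qkv, 3, axis=-1)

    scores = q @ k.transpose(0, 2, 1)    # identical scores -> uniform softmax
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    out = probs @ v                      # uniform average of identical V values
    print(out[0, 0, 0])                  # 769.0, matching the gold vector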