Commit 48187e79 authored by turneram's avatar turneram
Browse files

Formatting

parent 6202ea15
...@@ -190,12 +190,16 @@ def atanh_test(): ...@@ -190,12 +190,16 @@ def atanh_test():
@onnx_test @onnx_test
def attention_test(): def attention_test():
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 384, 768]) input = helper.make_tensor_value_info('input', TensorProto.FLOAT,
weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT, [768, 2304]) [2, 384, 768])
weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT,
[768, 2304])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304]) bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64, [2, 384]) mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64,
result = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 384, 768]) [2, 384])
result = helper.make_tensor_value_info('result', TensorProto.FLOAT,
[2, 384, 768])
node = helper.make_node('Attention', node = helper.make_node('Attention',
inputs=['input', 'weights', 'bias', 'mask_index'], inputs=['input', 'weights', 'bias', 'mask_index'],
outputs=['result'], outputs=['result'],
......
...@@ -22,9 +22,9 @@ TEST_CASE(attention_test) ...@@ -22,9 +22,9 @@ TEST_CASE(attention_test)
std::vector<float> bias_v(2304, 1); std::vector<float> bias_v(2304, 1);
std::vector<int64_t> mask_index_v(2 * 384, 1); std::vector<int64_t> mask_index_v(2 * 384, 1);
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["input"] = migraphx::argument(s_i, input_v.data()); pp["input"] = migraphx::argument(s_i, input_v.data());
pp["weights"] = migraphx::argument(s_w, weights_v.data()); pp["weights"] = migraphx::argument(s_w, weights_v.data());
pp["bias"] = migraphx::argument(s_b, bias_v.data()); pp["bias"] = migraphx::argument(s_b, bias_v.data());
pp["mask_index"] = migraphx::argument(s_m, mask_index_v.data()); pp["mask_index"] = migraphx::argument(s_m, mask_index_v.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment