Commit 48187e79 authored by turneram's avatar turneram
Browse files

Formatting

parent 6202ea15
......@@ -190,12 +190,16 @@ def atanh_test():
@onnx_test
def attention_test():
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 384, 768])
weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT, [768, 2304])
input = helper.make_tensor_value_info('input', TensorProto.FLOAT,
[2, 384, 768])
weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT,
[768, 2304])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64, [2, 384])
result = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 384, 768])
mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64,
[2, 384])
result = helper.make_tensor_value_info('result', TensorProto.FLOAT,
[2, 384, 768])
node = helper.make_node('Attention',
inputs=['input', 'weights', 'bias', 'mask_index'],
outputs=['result'],
......
......@@ -22,9 +22,9 @@ TEST_CASE(attention_test)
std::vector<float> bias_v(2304, 1);
std::vector<int64_t> mask_index_v(2 * 384, 1);
migraphx::parameter_map pp;
pp["input"] = migraphx::argument(s_i, input_v.data());
pp["weights"] = migraphx::argument(s_w, weights_v.data());
pp["bias"] = migraphx::argument(s_b, bias_v.data());
pp["input"] = migraphx::argument(s_i, input_v.data());
pp["weights"] = migraphx::argument(s_w, weights_v.data());
pp["bias"] = migraphx::argument(s_b, bias_v.data());
pp["mask_index"] = migraphx::argument(s_m, mask_index_v.data());
auto result = p.eval(pp).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment