Commit bb3a3e53 authored by turneram's avatar turneram
Browse files

Edit attention test to run in ort

parent 8bfea5f7
attention_test:ö
T
attention_test:…
c
input
weights
bias
mask_indexresult Attention_0" Attention*
num_heads  attention_testZ
num_heads  : com.microsoftattention_testZ
input


......@@ -21,7 +21,7 @@ mask_indexresult Attention_0" Attention*
€Z
mask_index



€b
result
......
......@@ -217,7 +217,7 @@ def attention_test():
weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT,
[768, 2304])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64,
mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT32,
[2, 384])
result = helper.make_tensor_value_info('result', TensorProto.FLOAT,
[2, 384, 768])
......@@ -227,6 +227,7 @@ def attention_test():
outputs=['result'],
num_heads=12,
name="Attention_0")
node.domain = "com.microsoft"
return ([node], [input, weights, bias, mask_index], [result])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment