Commit bb3a3e53 authored by turneram

Edit attention test to run in ort

parent 8bfea5f7
attention_test (binary ONNX protobuf; the rendered diff is not human-readable): the serialized test graph (inputs input, weights, bias, mask_index; node Attention_0 of type Attention with a num_heads attribute; output result) is regenerated so the Attention node carries the com.microsoft domain, matching the generator change below.
@@ -217,7 +217,7 @@ def attention_test():
     weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT,
                                             [768, 2304])
     bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
-    mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64,
+    mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT32,
                                                [2, 384])
     result = helper.make_tensor_value_info('result', TensorProto.FLOAT,
                                            [2, 384, 768])
@@ -227,6 +227,7 @@ def attention_test():
                             outputs=['result'],
                             num_heads=12,
                             name="Attention_0")
+    node.domain = "com.microsoft"
     return ([node], [input, weights, bias, mask_index], [result])
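Not part of this commit, but a minimal sketch of how the regenerated test graph could be exercised in onnxruntime to confirm the two changes (INT32 mask_index and the com.microsoft contrib-op domain). The opset versions, the random feeds, and the standalone build_attention_model helper are assumptions for illustration, not code from this repository.

# Sketch only: rebuild the attention_test graph and run it through onnxruntime.
import numpy as np
import onnx
from onnx import helper, TensorProto
import onnxruntime as ort

def build_attention_model():
    input_ = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 384, 768])
    weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT, [768, 2304])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
    # ORT's Attention contrib op expects an int32 mask, hence the INT64 -> INT32 change.
    mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT32, [2, 384])
    result = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 384, 768])

    node = helper.make_node('Attention',
                            inputs=['input', 'weights', 'bias', 'mask_index'],
                            outputs=['result'],
                            num_heads=12,
                            name="Attention_0")
    # Attention is a contrib op; ORT only resolves it in the com.microsoft domain.
    node.domain = "com.microsoft"

    graph = helper.make_graph([node], 'attention_test',
                              [input_, weights, bias, mask_index], [result])
    # Opset versions are assumptions; the com.microsoft opset must be imported
    # alongside the default one for the contrib node to load.
    return helper.make_model(graph, opset_imports=[
        helper.make_opsetid('', 13),
        helper.make_opsetid('com.microsoft', 1),
    ])

model = build_attention_model()
onnx.save(model, 'attention_test.onnx')

sess = ort.InferenceSession('attention_test.onnx', providers=['CPUExecutionProvider'])
feeds = {
    'input': np.random.rand(2, 384, 768).astype(np.float32),
    'weights': np.random.rand(768, 2304).astype(np.float32),
    'bias': np.random.rand(2304).astype(np.float32),
    'mask_index': np.ones((2, 384), dtype=np.int32),
}
out = sess.run(None, feeds)[0]
print(out.shape)  # expected (2, 384, 768)

With the original INT64 mask or the default (empty) domain, session creation fails because ORT cannot match the node against the com.microsoft Attention schema; the sketch above loads and runs once both fixes are applied.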