Commit f14e2a44 authored by turneram's avatar turneram
Browse files

Add num_heads to attention node

parent eff3d2d3
attention_test (binary ONNX test file): regenerated so the Attention_0 node (inputs: input, weights, bias, mask_index; output: result) carries the new num_heads attribute.
......
...@@ -203,6 +203,8 @@ def attention_test():
    node = helper.make_node('Attention',
                            inputs=['input', 'weights', 'bias', 'mask_index'],
                            outputs=['result'],
                            num_heads=12,
                            name="Attention_0")
    return ([node], [input, weights, bias, mask_index], [result])
......