Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
48187e79
Commit
48187e79
authored
May 20, 2022
by
turneram
Browse files
Formatting
parent
6202ea15
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
8 deletions
+12
-8
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+9
-5
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+3
-3
No files found.
test/onnx/gen_onnx.py
View file @
48187e79
...
@@ -190,12 +190,16 @@ def atanh_test():
...
@@ -190,12 +190,16 @@ def atanh_test():
@
onnx_test
@
onnx_test
def
attention_test
():
def
attention_test
():
input
=
helper
.
make_tensor_value_info
(
'input'
,
TensorProto
.
FLOAT
,
[
2
,
384
,
768
])
input
=
helper
.
make_tensor_value_info
(
'input'
,
TensorProto
.
FLOAT
,
weights
=
helper
.
make_tensor_value_info
(
'weights'
,
TensorProto
.
FLOAT
,
[
768
,
2304
])
[
2
,
384
,
768
])
weights
=
helper
.
make_tensor_value_info
(
'weights'
,
TensorProto
.
FLOAT
,
[
768
,
2304
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2304
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2304
])
mask_index
=
helper
.
make_tensor_value_info
(
'mask_index'
,
TensorProto
.
INT64
,
[
2
,
384
])
mask_index
=
helper
.
make_tensor_value_info
(
'mask_index'
,
TensorProto
.
INT64
,
result
=
helper
.
make_tensor_value_info
(
'result'
,
TensorProto
.
FLOAT
,
[
2
,
384
,
768
])
[
2
,
384
])
result
=
helper
.
make_tensor_value_info
(
'result'
,
TensorProto
.
FLOAT
,
[
2
,
384
,
768
])
node
=
helper
.
make_node
(
'Attention'
,
node
=
helper
.
make_node
(
'Attention'
,
inputs
=
[
'input'
,
'weights'
,
'bias'
,
'mask_index'
],
inputs
=
[
'input'
,
'weights'
,
'bias'
,
'mask_index'
],
outputs
=
[
'result'
],
outputs
=
[
'result'
],
...
...
test/onnx/verify_onnx.cpp
View file @
48187e79
...
@@ -22,9 +22,9 @@ TEST_CASE(attention_test)
...
@@ -22,9 +22,9 @@ TEST_CASE(attention_test)
std
::
vector
<
float
>
bias_v
(
2304
,
1
);
std
::
vector
<
float
>
bias_v
(
2304
,
1
);
std
::
vector
<
int64_t
>
mask_index_v
(
2
*
384
,
1
);
std
::
vector
<
int64_t
>
mask_index_v
(
2
*
384
,
1
);
migraphx
::
parameter_map
pp
;
migraphx
::
parameter_map
pp
;
pp
[
"input"
]
=
migraphx
::
argument
(
s_i
,
input_v
.
data
());
pp
[
"input"
]
=
migraphx
::
argument
(
s_i
,
input_v
.
data
());
pp
[
"weights"
]
=
migraphx
::
argument
(
s_w
,
weights_v
.
data
());
pp
[
"weights"
]
=
migraphx
::
argument
(
s_w
,
weights_v
.
data
());
pp
[
"bias"
]
=
migraphx
::
argument
(
s_b
,
bias_v
.
data
());
pp
[
"bias"
]
=
migraphx
::
argument
(
s_b
,
bias_v
.
data
());
pp
[
"mask_index"
]
=
migraphx
::
argument
(
s_m
,
mask_index_v
.
data
());
pp
[
"mask_index"
]
=
migraphx
::
argument
(
s_m
,
mask_index_v
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
auto
result
=
p
.
eval
(
pp
).
back
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment