Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4a39a0f7
Commit
4a39a0f7
authored
Oct 11, 2021
by
Shucai Xiao
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test
parents
5564172e
bb827865
Changes
542
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
106 additions
and
47 deletions
+106
-47
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+25
-0
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+27
-0
test/verify/gemm_2args_bmv.cpp
test/verify/gemm_2args_bmv.cpp
+1
-1
test/verify/gemm_2args_mm_1.cpp
test/verify/gemm_2args_mm_1.cpp
+4
-4
test/verify/gemm_2args_mm_2.cpp
test/verify/gemm_2args_mm_2.cpp
+4
-4
test/verify/gemm_2args_mm_3.cpp
test/verify/gemm_2args_mm_3.cpp
+3
-3
test/verify/gemm_2args_mm_4.cpp
test/verify/gemm_2args_mm_4.cpp
+3
-3
test/verify/gemm_2args_mm_5.cpp
test/verify/gemm_2args_mm_5.cpp
+1
-1
test/verify/gemm_2args_mm_6.cpp
test/verify/gemm_2args_mm_6.cpp
+2
-2
test/verify/gemm_2args_mm_7.cpp
test/verify/gemm_2args_mm_7.cpp
+1
-1
test/verify/gemm_2args_vbm.cpp
test/verify/gemm_2args_vbm.cpp
+1
-1
test/verify/gemm_2args_vv.cpp
test/verify/gemm_2args_vv.cpp
+2
-2
test/verify/gemm_multi_3args.cpp
test/verify/gemm_multi_3args.cpp
+2
-3
test/verify/gemm_multi_3args_alpha0.cpp
test/verify/gemm_multi_3args_alpha0.cpp
+2
-4
test/verify/gemm_multi_3args_beta0.cpp
test/verify/gemm_multi_3args_beta0.cpp
+2
-3
test/verify/gemm_multi_3args_c25.cpp
test/verify/gemm_multi_3args_c25.cpp
+2
-3
test/verify/gemm_multi_transpose.cpp
test/verify/gemm_multi_transpose.cpp
+6
-5
test/verify/main.cpp
test/verify/main.cpp
+9
-0
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+2
-1
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+7
-6
No files found.
test/verify/batch_quant_dot_4.cpp
0 → 100644
View file @
4a39a0f7
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
3
,
0
,
1
,
2
}}}),
l1
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
3
,
1
,
2
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
tl1
,
tl2
);
return
p
;
}
};
test/verify/batch_quant_dot_5.cpp
0 → 100644
View file @
4a39a0f7
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l1
);
auto
sl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
tl1
,
tl1
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
sl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
tl2
,
tl2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
sl1
,
sl2
);
return
p
;
}
};
test/verify/gemm_2args_bmv.cpp
View file @
4a39a0f7
...
...
@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
auto
bul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
5
,
1
}}}),
ul2
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
,
1
}}}),
ul2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
bul2
);
...
...
test/verify/gemm_2args_mm_1.cpp
View file @
4a39a0f7
...
...
@@ -12,10 +12,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
4
}}}),
l2
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
4
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
bl2
);
...
...
test/verify/gemm_2args_mm_2.cpp
View file @
4a39a0f7
...
...
@@ -12,10 +12,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
4
}}}),
l2
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
4
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
bl2
);
...
...
test/verify/gemm_2args_mm_3.cpp
View file @
4a39a0f7
...
...
@@ -12,9 +12,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
3
,
2
,
3
}}}),
l1
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bl1
,
l2
);
...
...
test/verify/gemm_2args_mm_4.cpp
View file @
4a39a0f7
...
...
@@ -12,9 +12,9 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
3
,
2
,
3
}}}),
l1
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bl1
,
l2
);
...
...
test/verify/gemm_2args_mm_5.cpp
View file @
4a39a0f7
...
...
@@ -14,7 +14,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bl1
,
l2
);
...
...
test/verify/gemm_2args_mm_6.cpp
View file @
4a39a0f7
...
...
@@ -14,10 +14,10 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
3
,
4
}}}),
l2
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
3
,
4
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bl1
,
bl2
);
...
...
test/verify/gemm_2args_mm_7.cpp
View file @
4a39a0f7
...
...
@@ -14,7 +14,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
2
,
3
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bl1
,
l2
);
...
...
test/verify/gemm_2args_vbm.cpp
View file @
4a39a0f7
...
...
@@ -15,7 +15,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
ul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
l1
);
auto
bul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
{
2
,
2
,
1
,
5
}}}),
ul1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
2
,
1
,
5
}}}),
ul1
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
...
...
test/verify/gemm_2args_vv.cpp
View file @
4a39a0f7
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
float
alpha
=
0.23
f
;
auto
res
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
}}),
ul1
,
ul2
);
auto
res
=
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
ul1
,
ul2
},
migraphx
::
make_op
(
"dot"
),
alpha
);
auto
sres
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
res
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
sres
);
...
...
test/verify/gemm_multi_3args.cpp
View file @
4a39a0f7
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
float
alpha
=
0.35
;
float
beta
=
0.41
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
,
l3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"dot"
),
alpha
,
beta
);
return
p
;
}
};
test/verify/gemm_multi_3args_alpha0.cpp
View file @
4a39a0f7
...
...
@@ -3,7 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct
gemm_multi_3args_alpha0
:
verify_program
<
gemm_multi_3args_alpha0
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float
alpha
=
0.0
f
;
float
beta
=
1.0
f
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
,
l3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"dot"
),
alpha
,
beta
);
return
p
;
}
};
test/verify/gemm_multi_3args_beta0.cpp
View file @
4a39a0f7
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
,
l3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"dot"
),
alpha
,
beta
);
return
p
;
}
};
test/verify/gemm_multi_3args_c25.cpp
View file @
4a39a0f7
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
float
alpha
=
0.35
;
float
beta
=
0.41
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
,
l3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"dot"
),
alpha
,
beta
);
return
p
;
}
};
test/verify/gemm_multi_transpose.cpp
View file @
4a39a0f7
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -12,14 +13,14 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"dims"
,
{
1
,
0
,
2
}}}),
l2
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
,
2
}}}),
l2
);
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
tl2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
},
migraphx
::
make_op
(
"dot"
),
alpha
,
beta
);
return
p
;
}
};
test/verify/main.cpp
100755 → 100644
View file @
4a39a0f7
...
...
@@ -45,5 +45,14 @@ int main(int argc, const char* argv[])
run_verify
rv
;
rv
.
add_validation_for
(
"gpu"
,
&
validate_gpu
);
rv
.
disable_test_for
(
"cpu"
,
{
"test_if_lp"
,
"test_if_param"
,
"test_if_literal"
});
rv
.
disable_test_for
(
"gpu"
,
{
"batch_quant_dot_2"
,
"batch_quant_dot_3"
,
"batch_quant_dot_5"
,
"quant_dot_3args_1"
,
"quant_dot_3args_2"
,
"quant_dot_3args_3"
,
"quant_dot_3args_4"
,
"quant_dot_3args_5"
});
rv
.
run
(
argc
,
argv
);
}
test/verify/quant_dot_3args_1.cpp
View file @
4a39a0f7
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
...
...
@@ -17,7 +18,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
m
m
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
l
1
,
l2
,
l3
);
m
igraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
return
p
;
}
};
test/verify/quant_dot_3args_2.cpp
View file @
4a39a0f7
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
...
...
@@ -14,12 +15,12 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"dims"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_
parameter
(
"b"
,
m2_shape
);
auto
l
3
=
mm
->
add_parameter
(
"
c
"
,
m
3
_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
,
{{
"alpha"
,
1
},
{
"beta"
,
3
}}
),
tl
1
,
l2
,
l
3
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_
instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l
2
=
mm
->
add_parameter
(
"
b
"
,
m
2
_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
return
p
;
}
};
Prev
1
…
19
20
21
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment