Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4e07dfcc
Commit
4e07dfcc
authored
Dec 03, 2023
by
Umang Yadav
Browse files
revert some changes
parent
050184cb
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
42 additions
and
96 deletions
+42
-96
test/verify/batch_quant_dot_1.cpp
test/verify/batch_quant_dot_1.cpp
+5
-13
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+4
-7
test/verify/batch_quant_dot_3.cpp
test/verify/batch_quant_dot_3.cpp
+3
-6
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+3
-6
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+3
-6
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+5
-13
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+5
-12
test/verify/quant_dot_3args_3.cpp
test/verify/quant_dot_3args_3.cpp
+5
-11
test/verify/quant_dot_3args_4.cpp
test/verify/quant_dot_3args_4.cpp
+5
-12
test/verify/quant_dot_3args_5.cpp
test/verify/quant_dot_3args_5.cpp
+4
-10
No files found.
test/verify/batch_quant_dot_1.cpp
View file @
4e07dfcc
...
@@ -24,23 +24,19 @@
...
@@ -24,23 +24,19 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
auto
tl1
=
mm
->
add_instruction
(
...
@@ -49,11 +45,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
...
@@ -49,11 +45,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
auto
tl2
=
mm
->
add_instruction
(
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_1
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_2.cpp
View file @
4e07dfcc
...
@@ -28,16 +28,15 @@
...
@@ -28,16 +28,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
migraphx
::
shape
::
type_t
DType
,
migraphx
::
shape
::
type_t
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
CT
ype
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_t
ype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -46,5 +45,3 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
...
@@ -46,5 +45,3 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
int8_type
,
migraphx
::
shape
::
int32_type
>;
template
struct
batch_quant_dot_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
,
migraphx
::
shape
::
float_type
>;
test/verify/batch_quant_dot_3.cpp
View file @
4e07dfcc
...
@@ -27,15 +27,14 @@
...
@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
6
,
7
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
6
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -43,5 +42,3 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
...
@@ -43,5 +42,3 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_4.cpp
View file @
4e07dfcc
...
@@ -27,15 +27,14 @@
...
@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
7
,
2
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -47,5 +46,3 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
...
@@ -47,5 +46,3 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_5.cpp
View file @
4e07dfcc
...
@@ -27,15 +27,14 @@
...
@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
5
,
7
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -49,5 +48,3 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
...
@@ -49,5 +48,3 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_dot_3args_1.cpp
View file @
4e07dfcc
...
@@ -25,31 +25,23 @@
...
@@ -25,31 +25,23 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
1
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_1
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_2.cpp
View file @
4e07dfcc
...
@@ -28,29 +28,22 @@
...
@@ -28,29 +28,22 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_2
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_3.cpp
View file @
4e07dfcc
...
@@ -28,28 +28,22 @@
...
@@ -28,28 +28,22 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
2
,
3
);
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
2
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_3
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_3
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_4.cpp
View file @
4e07dfcc
...
@@ -28,18 +28,15 @@
...
@@ -28,18 +28,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -48,11 +45,7 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
...
@@ -48,11 +45,7 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_4
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_4
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_5.cpp
View file @
4e07dfcc
...
@@ -28,17 +28,14 @@
...
@@ -28,17 +28,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
6
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
6
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -46,10 +43,7 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
...
@@ -46,10 +43,7 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
}
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
3
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_5
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_5
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment