Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4926f035
Commit
4926f035
authored
Dec 06, 2023
by
Umang Yadav
Browse files
Merge branch 'quant_gemm_fp8' into mobilenet_fp8
parents
acd9bd3e
1ce916c4
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
149 additions
and
58 deletions
+149
-58
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+16
-0
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+8
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+5
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
src/targets/ref/target.cpp
src/targets/ref/target.cpp
+15
-11
test/verify/batch_quant_dot_1.cpp
test/verify/batch_quant_dot_1.cpp
+13
-5
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+13
-5
test/verify/batch_quant_dot_3.cpp
test/verify/batch_quant_dot_3.cpp
+6
-3
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+6
-3
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+6
-3
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+13
-5
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+12
-5
test/verify/quant_dot_3args_3.cpp
test/verify/quant_dot_3args_3.cpp
+11
-5
test/verify/quant_dot_3args_4.cpp
test/verify/quant_dot_3args_4.cpp
+12
-5
test/verify/quant_dot_3args_5.cpp
test/verify/quant_dot_3args_5.cpp
+10
-4
No files found.
src/eliminate_data_type.cpp
View file @
4926f035
...
...
@@ -120,6 +120,22 @@ void eliminate_data_type::apply(module& m) const
if
(
contains
(
unsupported_ops
,
"all"
)
or
contains
(
unsupported_ops
,
ins
->
name
()))
insert_convert_to_supported_type
(
m
,
ins
,
target_type
,
unsupported_types
);
}
// remove nested converts
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"convert"
)
{
auto
convert_input
=
ins
->
inputs
().
front
();
while
(
convert_input
->
name
()
==
"convert"
)
{
convert_input
=
convert_input
->
inputs
().
front
();
}
if
(
convert_input
->
get_shape
()
==
ins
->
get_shape
())
{
m
.
replace_instruction
(
ins
,
convert_input
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/quant_dot.hpp
View file @
4926f035
...
...
@@ -44,9 +44,11 @@ struct quant_dot
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
suppported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
suppported_types
,
t
))
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t
and fp8e4m3fnuz_type
"
);
}
if
(
not
std
::
all_of
(
...
...
@@ -73,6 +75,10 @@ struct quant_dot
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
t
==
shape
::
fp8e4m3fnuz_type
)
{
return
{
shape
::
float_type
,
out_lens
};
}
// else int8 gemm
return
{
shape
::
int32_type
,
out_lens
};
}
};
...
...
src/simplify_reshapes.cpp
View file @
4926f035
...
...
@@ -183,6 +183,11 @@ struct find_nested_convert
auto
x
=
ins
->
inputs
().
front
();
auto
input
=
x
->
inputs
().
front
();
while
(
input
->
name
()
==
"convert"
)
{
input
=
input
->
inputs
().
front
();
}
if
(
ins
->
get_shape
()
!=
input
->
get_shape
())
return
;
...
...
src/targets/gpu/gemm_impl.cpp
View file @
4926f035
...
...
@@ -195,7 +195,7 @@ struct gemm_impl
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg
_type
;
output_type
=
get
_type
(
input_shapes
[
2
].
type
())
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
4926f035
...
...
@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
or
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
{
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
...
...
src/targets/gpu/target.cpp
View file @
4926f035
...
...
@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
unsupported_fp8_ops
.
insert
(
"quant_dot"
);
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops
.
insert
(
"pooling"
);
...
...
src/targets/ref/target.cpp
View file @
4926f035
...
...
@@ -24,6 +24,7 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/lowering.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp>
...
...
@@ -42,9 +43,12 @@ std::string target::name() const { return "ref"; }
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
,
const
compile_options
&
)
const
{
return
{
normalize_ops
{},
return
{
normalize_ops
{},
eliminate_pad
{},
dead_code_elimination
{},
eliminate_data_type
{{
migraphx
::
shape
::
fp8e4m3fnuz_type
},
shape
::
float_type
,
{
"quant_dot"
}},
dead_code_elimination
{},
insert_pad
{},
dead_code_elimination
{},
rewrite_rnn
{},
...
...
test/verify/batch_quant_dot_1.cpp
View file @
4926f035
...
...
@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
...
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
}
};
template
struct
batch_quant_dot_1
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_2.cpp
View file @
4926f035
...
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
}
};
template
struct
batch_quant_dot_2
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_3.cpp
View file @
4926f035
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
6
,
7
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
6
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return
p
;
}
};
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_4.cpp
View file @
4926f035
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
7
,
2
,
6
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return
p
;
}
};
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_5.cpp
View file @
4926f035
...
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
5
,
7
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return
p
;
}
};
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_dot_3args_1.cpp
View file @
4926f035
...
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
1
});
return
p
;
}
};
template
struct
quant_dot_3args_1
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_2.cpp
View file @
4926f035
...
...
@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
}
};
template
struct
quant_dot_3args_2
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_3.cpp
View file @
4926f035
...
...
@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
2
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
2
},
CType
{
3
});
return
p
;
}
};
template
struct
quant_dot_3args_3
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_3
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_4.cpp
View file @
4926f035
...
...
@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
...
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
}
};
template
struct
quant_dot_3args_4
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_4
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_5.cpp
View file @
4926f035
...
...
@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
<
DType
,
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
6
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m1_shape
{
dtype
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
6
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
...
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
}
);
return
p
;
}
};
template
struct
quant_dot_3args_5
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_5
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment