Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cac6c759
Commit
cac6c759
authored
Dec 13, 2023
by
Paul
Browse files
Merge
parents
4bde67c4
a60bdb67
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
138 additions
and
56 deletions
+138
-56
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-3
test/quantization.cpp
test/quantization.cpp
+19
-8
test/verify/batch_quant_dot_1.cpp
test/verify/batch_quant_dot_1.cpp
+13
-5
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+13
-5
test/verify/batch_quant_dot_3.cpp
test/verify/batch_quant_dot_3.cpp
+6
-3
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+6
-3
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+6
-3
test/verify/main.cpp
test/verify/main.cpp
+15
-0
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+13
-5
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+12
-5
test/verify/quant_dot_3args_3.cpp
test/verify/quant_dot_3args_3.cpp
+11
-5
test/verify/quant_dot_3args_4.cpp
test/verify/quant_dot_3args_4.cpp
+12
-5
test/verify/quant_dot_3args_5.cpp
test/verify/quant_dot_3args_5.cpp
+10
-4
tools/api/api.cpp
tools/api/api.cpp
+2
-2
No files found.
test/py/onnx_backend_test.py
View file @
cac6c759
...
@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
...
@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_convtranspose_1d_cpu'
)
backend_test
.
exclude
(
r
'test_convtranspose_1d_cpu'
)
backend_test
.
exclude
(
r
'test_det_2d_cpu'
)
backend_test
.
exclude
(
r
'test_det_2d_cpu'
)
backend_test
.
exclude
(
r
'test_det_nd_cpu'
)
backend_test
.
exclude
(
r
'test_det_nd_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_max_adjusted_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_min_adjusted_cpu'
)
backend_test
.
exclude
(
r
'test_edge_pad_cpu'
)
backend_test
.
exclude
(
r
'test_edge_pad_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_diagonal_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_diagonal_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_matmul_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_matmul_cpu'
)
...
...
test/quantization.cpp
View file @
cac6c759
...
@@ -30,7 +30,7 @@
...
@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
EXPECT
(
p
==
qp
);
EXPECT
(
p
==
qp
);
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto
p
=
create_program
();
auto
p
=
create_program
();
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
test
::
throws
([
&
]
{
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"add"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
});
});
}
}
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
...
@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
,
"dot"
},
quant_params
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p1
);
optimize_prog_int8
(
p1
);
auto
p2
=
create_int8_program
();
auto
p2
=
create_int8_program
();
...
...
test/verify/batch_quant_dot_1.cpp
View file @
cac6c759
...
@@ -24,19 +24,23 @@
...
@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
auto
tl1
=
mm
->
add_instruction
(
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto
tl2
=
mm
->
add_instruction
(
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_1
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_2.cpp
View file @
cac6c759
...
@@ -25,23 +25,31 @@
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
2
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_2
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_3.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
6
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
6
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_4.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
7
,
2
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_5.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
5
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/main.cpp
View file @
cac6c759
...
@@ -78,6 +78,16 @@ int main(int argc, const char* argv[])
...
@@ -78,6 +78,16 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim"
,
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
,
// these tests are disabled due issue of lossy downcast, see issue#2517
#if defined(__GNUC__) and !defined(__clang__)
"batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
"quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
"quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
#else
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
,
#endif
"test_block_reduce_small<3, migraphx::shape::int8_type>"
,
"test_block_reduce_small<3, migraphx::shape::int8_type>"
,
"test_block_reduce_small<4, migraphx::shape::int8_type>"
,
"test_block_reduce_small<4, migraphx::shape::int8_type>"
,
"test_block_reduce_small<8, migraphx::shape::int8_type>"
,
"test_block_reduce_small<8, migraphx::shape::int8_type>"
,
...
@@ -89,5 +99,10 @@ int main(int argc, const char* argv[])
...
@@ -89,5 +99,10 @@ int main(int argc, const char* argv[])
"test_block_reduce_small<128, migraphx::shape::int8_type>"
,
"test_block_reduce_small<128, migraphx::shape::int8_type>"
,
"test_block_reduce_small<129, migraphx::shape::int8_type>"
,
"test_block_reduce_small<129, migraphx::shape::int8_type>"
,
});
});
rv
.
disable_test_for
(
"gpu"
,
{
// These passes on MI300 but fails on others, same issue as CPU.
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
});
rv
.
run
(
argc
,
argv
);
rv
.
run
(
argc
,
argv
);
}
}
test/verify/quant_dot_3args_1.cpp
View file @
cac6c759
...
@@ -25,23 +25,31 @@
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
1
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_1
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_2.cpp
View file @
cac6c759
...
@@ -28,22 +28,29 @@
...
@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_2
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_3.cpp
View file @
cac6c759
...
@@ -28,22 +28,28 @@
...
@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
2
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
2
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_3
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_3
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_4.cpp
View file @
cac6c759
...
@@ -28,15 +28,18 @@
...
@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_4
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_4
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_5.cpp
View file @
cac6c759
...
@@ -28,14 +28,17 @@
...
@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
6
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
6
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
6
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
}
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_5
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_5
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
tools/api/api.cpp
View file @
cac6c759
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment