Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cac6c759
"mmdet/vscode:/vscode.git/clone" did not exist on "2017c81e0ce0dd3eea63dc9cf56c3e535ae67ec1"
Commit
cac6c759
authored
Dec 13, 2023
by
Paul
Browse files
Merge
parents
4bde67c4
a60bdb67
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
138 additions
and
56 deletions
+138
-56
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-3
test/quantization.cpp
test/quantization.cpp
+19
-8
test/verify/batch_quant_dot_1.cpp
test/verify/batch_quant_dot_1.cpp
+13
-5
test/verify/batch_quant_dot_2.cpp
test/verify/batch_quant_dot_2.cpp
+13
-5
test/verify/batch_quant_dot_3.cpp
test/verify/batch_quant_dot_3.cpp
+6
-3
test/verify/batch_quant_dot_4.cpp
test/verify/batch_quant_dot_4.cpp
+6
-3
test/verify/batch_quant_dot_5.cpp
test/verify/batch_quant_dot_5.cpp
+6
-3
test/verify/main.cpp
test/verify/main.cpp
+15
-0
test/verify/quant_dot_3args_1.cpp
test/verify/quant_dot_3args_1.cpp
+13
-5
test/verify/quant_dot_3args_2.cpp
test/verify/quant_dot_3args_2.cpp
+12
-5
test/verify/quant_dot_3args_3.cpp
test/verify/quant_dot_3args_3.cpp
+11
-5
test/verify/quant_dot_3args_4.cpp
test/verify/quant_dot_3args_4.cpp
+12
-5
test/verify/quant_dot_3args_5.cpp
test/verify/quant_dot_3args_5.cpp
+10
-4
tools/api/api.cpp
tools/api/api.cpp
+2
-2
No files found.
test/py/onnx_backend_test.py
View file @
cac6c759
...
@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
...
@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_convtranspose_1d_cpu'
)
backend_test
.
exclude
(
r
'test_convtranspose_1d_cpu'
)
backend_test
.
exclude
(
r
'test_det_2d_cpu'
)
backend_test
.
exclude
(
r
'test_det_2d_cpu'
)
backend_test
.
exclude
(
r
'test_det_nd_cpu'
)
backend_test
.
exclude
(
r
'test_det_nd_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_max_adjusted_cpu'
)
backend_test
.
exclude
(
r
'test_dynamicquantizelinear_min_adjusted_cpu'
)
backend_test
.
exclude
(
r
'test_edge_pad_cpu'
)
backend_test
.
exclude
(
r
'test_edge_pad_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_diagonal_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_diagonal_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_matmul_cpu'
)
backend_test
.
exclude
(
r
'test_einsum_batch_matmul_cpu'
)
...
...
test/quantization.cpp
View file @
cac6c759
...
@@ -30,7 +30,7 @@
...
@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
EXPECT
(
p
==
qp
);
EXPECT
(
p
==
qp
);
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto
p
=
create_program
();
auto
p
=
create_program
();
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
test
::
throws
([
&
]
{
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"add"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
});
});
}
}
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
...
@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
,
"dot"
},
quant_params
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p1
);
optimize_prog_int8
(
p1
);
auto
p2
=
create_int8_program
();
auto
p2
=
create_int8_program
();
...
...
test/verify/batch_quant_dot_1.cpp
View file @
cac6c759
...
@@ -24,19 +24,23 @@
...
@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_1
:
verify_program
<
batch_quant_dot_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
7
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
mm
->
add_instruction
(
auto
tl1
=
mm
->
add_instruction
(
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
...
@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto
tl2
=
mm
->
add_instruction
(
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_1
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_2.cpp
View file @
cac6c759
...
@@ -25,23 +25,31 @@
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
>
template
<
typename
DType
,
typename
CType
>
struct
batch_quant_dot_2
:
verify_program
<
batch_quant_dot_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
2
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
{};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
8
,
7
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
{};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
3
,
2
,
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
3
,
2
,
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
3
,
2
,
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_2
<
int8_t
,
int32_t
>;
template
struct
batch_quant_dot_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/batch_quant_dot_3.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_3
:
verify_program
<
batch_quant_dot_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
2
,
6
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
6
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
6
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
...
@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_4.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_4
:
verify_program
<
batch_quant_dot_4
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
4
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
7
,
2
,
6
,
3
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
7
,
2
,
6
,
3
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
...
@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_4
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/batch_quant_dot_5.cpp
View file @
cac6c759
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
batch_quant_dot_5
:
verify_program
<
batch_quant_dot_5
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
3
,
2
,
7
,
2
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
3
,
2
,
5
,
7
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
2
,
5
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
...
@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return
p
;
return
p
;
}
}
};
};
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
int8_type
>;
template
struct
batch_quant_dot_5
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/main.cpp
View file @
cac6c759
...
@@ -78,6 +78,16 @@ int main(int argc, const char* argv[])
...
@@ -78,6 +78,16 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim"
,
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
,
// these tests are disabled due issue of lossy downcast, see issue#2517
#if defined(__GNUC__) and !defined(__clang__)
"batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
"quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
"quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>"
,
#else
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
,
#endif
"test_block_reduce_small<3, migraphx::shape::int8_type>"
,
"test_block_reduce_small<3, migraphx::shape::int8_type>"
,
"test_block_reduce_small<4, migraphx::shape::int8_type>"
,
"test_block_reduce_small<4, migraphx::shape::int8_type>"
,
"test_block_reduce_small<8, migraphx::shape::int8_type>"
,
"test_block_reduce_small<8, migraphx::shape::int8_type>"
,
...
@@ -89,5 +99,10 @@ int main(int argc, const char* argv[])
...
@@ -89,5 +99,10 @@ int main(int argc, const char* argv[])
"test_block_reduce_small<128, migraphx::shape::int8_type>"
,
"test_block_reduce_small<128, migraphx::shape::int8_type>"
,
"test_block_reduce_small<129, migraphx::shape::int8_type>"
,
"test_block_reduce_small<129, migraphx::shape::int8_type>"
,
});
});
rv
.
disable_test_for
(
"gpu"
,
{
// These passes on MI300 but fails on others, same issue as CPU.
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>"
,
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
});
rv
.
run
(
argc
,
argv
);
rv
.
run
(
argc
,
argv
);
}
}
test/verify/quant_dot_3args_1.cpp
View file @
cac6c759
...
@@ -25,23 +25,31 @@
...
@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_1
:
verify_program
<
quant_dot_3args_1
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
1
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
1
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_1
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_1
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_2.cpp
View file @
cac6c759
...
@@ -28,22 +28,29 @@
...
@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_2
:
verify_program
<
quant_dot_3args_2
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
7
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
8
,
7
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l1
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
1
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
l2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
1
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_2
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_2
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_3.cpp
View file @
cac6c759
...
@@ -28,22 +28,28 @@
...
@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_3
:
verify_program
<
quant_dot_3args_3
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
8
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
2
,
8
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
2
,
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
l1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
2
},
CType
{
3
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_3
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_3
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_4.cpp
View file @
cac6c759
...
@@ -28,15 +28,18 @@
...
@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_4
:
verify_program
<
quant_dot_3args_4
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
8
,
2
}};
auto
ctype
=
migraphx
::
shape
::
get_type
<
CType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
8
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
7
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
8
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
8
}};
migraphx
::
shape
m3_shape
{
ctype
,
{
2
,
7
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
...
@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
auto
l3
=
mm
->
add_parameter
(
"c"
,
m3_shape
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
3
,
2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
,
l3
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
},
CType
{
2
});
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_4
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_4
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
test/verify/quant_dot_3args_5.cpp
View file @
cac6c759
...
@@ -28,14 +28,17 @@
...
@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
>
template
<
typename
DType
,
typename
CType
>
struct
quant_dot_3args_5
:
verify_program
<
quant_dot_3args_5
<
DType
,
CType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
int8_type
,
{
6
,
2
}};
auto
dtype
=
migraphx
::
shape
::
get_type
<
DType
>
();
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
int8_type
,
{
7
,
6
}};
migraphx
::
shape
m1_shape
{
dtype
,
{
6
,
2
}};
migraphx
::
shape
m2_shape
{
dtype
,
{
7
,
6
}};
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"a"
,
m1_shape
);
auto
tl1
=
auto
tl1
=
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
...
@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"b"
,
m2_shape
);
auto
tl2
=
auto
tl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
l2
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
3
);
migraphx
::
add_apply_alpha_beta
(
*
mm
,
{
tl1
,
tl2
},
migraphx
::
make_op
(
"quant_dot"
),
CType
{
3
}
);
return
p
;
return
p
;
}
}
};
};
template
struct
quant_dot_3args_5
<
int8_t
,
int32_t
>;
template
struct
quant_dot_3args_5
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
float
>;
tools/api/api.cpp
View file @
cac6c759
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment