Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
370d18c1
Commit
370d18c1
authored
Dec 03, 2023
by
Umang Yadav
Browse files
add quant_conv tests
parent
4e07dfcc
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
60 additions
and
26 deletions
+60
-26
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+11
-5
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+5
-3
test/verify/quant_conv.cpp
test/verify/quant_conv.cpp
+7
-3
test/verify/quant_conv_1.cpp
test/verify/quant_conv_1.cpp
+7
-3
test/verify/quant_conv_1d.cpp
test/verify/quant_conv_1d.cpp
+8
-3
test/verify/quant_conv_2.cpp
test/verify/quant_conv_2.cpp
+9
-3
test/verify/quant_conv_padding.cpp
test/verify/quant_conv_padding.cpp
+7
-3
test/verify/quant_conv_padding_stride.cpp
test/verify/quant_conv_padding_stride.cpp
+6
-3
No files found.
src/include/migraphx/op/quant_convolution.hpp
View file @
370d18c1
...
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include "migraphx/shape.hpp"
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
...
...
@@ -87,11 +88,13 @@ struct quant_convolution
}
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
supported_types
,
t
))
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type"
);
}
t
=
shape
::
int32_type
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
auto
padding_size
=
padding
.
size
();
...
...
@@ -107,8 +110,11 @@ struct quant_convolution
stride
[
i
]
+
1
)));
}
return
inputs
[
0
].
with_lens
(
t
,
output_lens
);
if
(
t
==
shape
::
int8_type
)
{
return
inputs
[
0
].
with_lens
(
shape
::
int32_type
,
output_lens
);
}
// else fp8 conv
return
inputs
[
0
].
with_lens
(
shape
::
float_type
,
output_lens
);
}
size_t
kdims
()
const
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
370d18c1
...
...
@@ -214,6 +214,7 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
auto
input_arg_t
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
...
...
@@ -223,6 +224,8 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
float_type
and
input_arg_t
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
...
...
@@ -403,8 +406,7 @@ struct find_mlir_standalone_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
gemm_based_op
=
r
.
result
;
//
// enable only for fp32/fp16/i8 types
// enable only for fp32/fp16/i8/fp8 types
if
(
std
::
any_of
(
gemm_based_op
->
inputs
().
begin
(),
gemm_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
fp8e4m3fnuz_type
},
...
...
@@ -530,7 +532,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
int8
)},
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
all
)},
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
#else
(
void
)
mpm
;
...
...
test/verify/quant_conv.cpp
View file @
370d18c1
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv
:
verify_program
<
quant_conv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv
:
verify_program
<
quant_conv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1.cpp
View file @
370d18c1
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv_1
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1d.cpp
View file @
370d18c1
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
2
,
3
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
2
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
...
...
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return
p
;
}
};
template
struct
quant_conv_1d
<
migraphx
::
shape
::
int8_type
>;
// MLIR 1D convolution is not supported in MIGraphX yet.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
test/verify/quant_conv_2.cpp
View file @
370d18c1
...
...
@@ -27,17 +27,23 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
16
,
16
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
16
,
16
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
16
,
16
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
16
,
16
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv_2
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_padding.cpp
View file @
370d18c1
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_padding
:
verify_program
<
quant_conv_padding
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_padding
:
verify_program
<
quant_conv_padding
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
1
,
1
}}}),
...
...
@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return
p
;
}
};
template
struct
quant_conv_padding
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_padding
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_padding_stride.cpp
View file @
370d18c1
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_padding_stride
:
verify_program
<
quant_conv_padding_stride
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_padding_stride
:
verify_program
<
quant_conv_padding_stride
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DType
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DType
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}}}),
...
...
@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return
p
;
}
};
template
struct
quant_conv_padding_stride
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_padding_stride
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment