Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6a72e8fc
"vscode:/vscode.git/clone" did not exist on "7d5a155e4aa258b0db1649bdcbf29477bf490ac2"
Unverified
Commit
6a72e8fc
authored
Dec 06, 2023
by
Umang Yadav
Committed by
GitHub
Dec 06, 2023
Browse files
FP8 2D forward convolution using rocMLIR (#2507)
parent
a09dc502
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
132 additions
and
64 deletions
+132
-64
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+11
-5
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+6
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+18
-6
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+2
-0
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-0
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+1
-2
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+8
-0
test/verify/main.cpp
test/verify/main.cpp
+0
-1
test/verify/quant_conv.cpp
test/verify/quant_conv.cpp
+7
-3
test/verify/quant_conv_1.cpp
test/verify/quant_conv_1.cpp
+7
-3
test/verify/quant_conv_1d.cpp
test/verify/quant_conv_1d.cpp
+8
-3
test/verify/quant_conv_2.cpp
test/verify/quant_conv_2.cpp
+7
-3
test/verify/quant_conv_padding.cpp
test/verify/quant_conv_padding.cpp
+7
-3
test/verify/quant_conv_padding_stride.cpp
test/verify/quant_conv_padding_stride.cpp
+6
-3
test/verify/run_verify.cpp
test/verify/run_verify.cpp
+2
-1
test/verify/test_conv.cpp
test/verify/test_conv.cpp
+7
-5
test/verify/test_conv2.cpp
test/verify/test_conv2.cpp
+6
-5
test/verify/test_conv_add.cpp
test/verify/test_conv_add.cpp
+9
-7
test/verify/test_conv_add_1x1_diff_strides.cpp
test/verify/test_conv_add_1x1_diff_strides.cpp
+9
-7
test/verify/test_conv_add_relu.cpp
test/verify/test_conv_add_relu.cpp
+9
-7
No files found.
src/include/migraphx/op/quant_convolution.hpp
View file @
6a72e8fc
...
...
@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
...
...
@@ -87,11 +88,13 @@ struct quant_convolution
}
// all input type must be int8_type and output is float_type
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
shape
::
int8_type
,
shape
::
fp8e4m3fnuz_type
};
if
(
not
contains
(
supported_types
,
t
))
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type"
);
}
t
=
shape
::
int32_type
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
auto
padding_size
=
padding
.
size
();
...
...
@@ -107,8 +110,11 @@ struct quant_convolution
stride
[
i
]
+
1
)));
}
return
inputs
[
0
].
with_lens
(
t
,
output_lens
);
if
(
t
==
shape
::
int8_type
)
{
return
inputs
[
0
].
with_lens
(
shape
::
int32_type
,
output_lens
);
}
// else fp8 conv
return
inputs
[
0
].
with_lens
(
shape
::
float_type
,
output_lens
);
}
size_t
kdims
()
const
...
...
src/targets/gpu/device_name.cpp
View file @
6a72e8fc
...
...
@@ -49,6 +49,12 @@ std::string get_device_name()
return
props
.
gcnArchName
;
}
bool
gfx_has_fp8_intrinsics
()
{
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
6a72e8fc
...
...
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
auto
input_arg_t
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
...
...
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
float_type
and
input_arg_t
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
...
...
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
...
...
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
...
...
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported.
if
(
is_float
and
name
==
"convert"
)
{
if
(
result_type
==
shape
::
fp8e4m3fnuz_type
)
{
return
false
;
}
// else
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
...
...
@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
gemm_based_op
=
r
.
result
;
//
// enable only for fp32/fp16/i8 types
// enable only for fp32/fp16/i8/fp8 types
if
(
std
::
any_of
(
gemm_based_op
->
inputs
().
begin
(),
gemm_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
fp8e4m3fnuz_type
},
i
->
get_shape
().
type
());
}))
return
;
static
size_t
counter
=
0
;
...
...
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
6a72e8fc
...
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
bool
gfx_has_fp8_intrinsics
();
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/mlir.cpp
View file @
6a72e8fc
...
...
@@ -300,6 +300,8 @@ struct mlir_program
result
=
mlirF32TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
half_type
)
result
=
mlirF16TypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
fp8e4m3fnuz_type
)
result
=
mlirFloat8E4M3FNUZTypeGet
(
ctx
.
get
());
else
if
(
as
.
type_enum
()
==
shape
::
double_type
)
result
=
mlirF64TypeGet
(
ctx
.
get
());
else
if
(
as
.
is_integral
())
...
...
src/targets/gpu/rocblas.cpp
View file @
6a72e8fc
...
...
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return
false
;
#else
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
return
gfx_has_fp8_intrinsics
();
#endif
}
...
...
src/targets/gpu/target.cpp
View file @
6a72e8fc
...
...
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// whiltelist supported Ops for the FP8
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops
.
insert
(
"pooling"
);
if
(
not
gpu
::
gfx_has_fp8_intrinsics
())
{
unsupported_fp8_ops
.
insert
(
"convolution"
);
unsupported_fp8_ops
.
insert
(
"quant_convolution"
);
}
// add all device kernels
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
...
...
test/verify/main.cpp
View file @
6a72e8fc
...
...
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim"
,
"test_instancenorm_large_3d<migraphx::shape::float_type>"
,
"test_instancenorm_large_3d<migraphx::shape::half_type>"
});
rv
.
disable_test_for
(
"gpu"
,
{
"test_conv_bn_add"
});
rv
.
run
(
argc
,
argv
);
}
test/verify/quant_conv.cpp
View file @
6a72e8fc
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv
:
verify_program
<
quant_conv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv
:
verify_program
<
quant_conv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1.cpp
View file @
6a72e8fc
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1
:
verify_program
<
quant_conv_1
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv_1
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_1d.cpp
View file @
6a72e8fc
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_1d
:
verify_program
<
quant_conv_1d
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
...
...
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return
p
;
}
};
template
struct
quant_conv_1d
<
migraphx
::
shape
::
int8_type
>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
test/verify/quant_conv_2.cpp
View file @
6a72e8fc
...
...
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_2
:
verify_program
<
quant_conv_2
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
16
,
16
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
16
,
16
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
16
,
16
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
16
,
16
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
0
,
0
}},
{{
1
,
1
}},
{{
1
,
1
}}},
pa
,
pc
);
return
p
;
}
};
template
struct
quant_conv_2
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_padding.cpp
View file @
6a72e8fc
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_padding
:
verify_program
<
quant_conv_padding
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_padding
:
verify_program
<
quant_conv_padding
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
1
,
1
}}}),
...
...
@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return
p
;
}
};
template
struct
quant_conv_padding
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_padding
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/quant_conv_padding_stride.cpp
View file @
6a72e8fc
...
...
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
quant_conv_padding_stride
:
verify_program
<
quant_conv_padding_stride
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
quant_conv_padding_stride
:
verify_program
<
quant_conv_padding_stride
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
DT
ype
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_t
ype
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
c_shape
{
DT
ype
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
mm
->
add_parameter
(
"c"
,
c_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}}}),
...
...
@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return
p
;
}
};
template
struct
quant_conv_padding_stride
<
migraphx
::
shape
::
int8_type
>;
template
struct
quant_conv_padding_stride
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/run_verify.cpp
View file @
6a72e8fc
...
...
@@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
{
migraphx
::
target
t
=
migraphx
::
make_target
(
"ref"
);
auto_print
pp
{
p
,
t
.
name
()};
compile_check
(
p
,
t
,
c_opts
);
auto
trace_target
=
migraphx
::
string_value_of
(
MIGRAPHX_TRACE_TEST_COMPILE
{});
compile_check
(
p
,
t
,
c_opts
,
(
trace_target
==
"ref"
));
return
p
.
eval
(
std
::
move
(
inputs
));
}
std
::
pair
<
migraphx
::
program
,
std
::
vector
<
migraphx
::
argument
>>
...
...
test/verify/test_conv.cpp
View file @
6a72e8fc
...
...
@@ -27,17 +27,19 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_conv
:
verify_program
<
test_conv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_conv
:
verify_program
<
test_conv
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
DType
,
{
4
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
DType
,
{
4
,
3
,
3
,
3
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
input
,
weights
);
return
p
;
}
};
template
struct
test_conv
<
migraphx
::
shape
::
float_type
>;
template
struct
test_conv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_conv2.cpp
View file @
6a72e8fc
...
...
@@ -27,16 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_conv2
:
verify_program
<
test_conv2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_conv2
:
verify_program
<
test_conv2
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
512
,
28
,
28
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
512
,
1
,
1
}});
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
DType
,
{
1
,
512
,
28
,
28
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
DType
,
{
256
,
512
,
1
,
1
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}}}),
...
...
@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return
p
;
}
};
template
struct
test_conv2
<
migraphx
::
shape
::
float_type
>;
template
struct
test_conv2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_conv_add.cpp
View file @
6a72e8fc
...
...
@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_conv_add
:
verify_program
<
test_conv_add
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_conv_add
:
verify_program
<
test_conv_add
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
4
,
4
}});
auto
w
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
2
,
8
,
3
,
3
}},
1
));
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
4
,
4
}});
auto
v
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
2
,
8
,
3
,
3
}},
2
));
auto
x
=
mm
->
add_parameter
(
"x"
,
{
DType
,
{
1
,
8
,
4
,
4
}});
auto
w
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
2
,
8
,
3
,
3
}},
1
));
auto
y
=
mm
->
add_parameter
(
"y"
,
{
DType
,
{
1
,
8
,
4
,
4
}});
auto
v
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
2
,
8
,
3
,
3
}},
2
));
auto
conv1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
x
,
w
);
auto
conv2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
y
,
v
);
auto
sum
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv1
,
conv2
);
...
...
@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add>
return
p
;
}
};
template
struct
test_conv_add
<
migraphx
::
shape
::
float_type
>;
template
struct
test_conv_add
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_conv_add_1x1_diff_strides.cpp
View file @
6a72e8fc
...
...
@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_conv_add_1x1_diff_strides
:
verify_program
<
test_conv_add_1x1_diff_strides
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_conv_add_1x1_diff_strides
:
verify_program
<
test_conv_add_1x1_diff_strides
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
2
,
2
}});
auto
w
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
2
,
8
,
1
,
1
}},
1
));
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
4
,
4
}});
auto
v
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
2
,
8
,
1
,
1
}},
2
));
auto
x
=
mm
->
add_parameter
(
"x"
,
{
DType
,
{
1
,
8
,
2
,
2
}});
auto
w
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
2
,
8
,
1
,
1
}},
1
));
auto
y
=
mm
->
add_parameter
(
"y"
,
{
DType
,
{
1
,
8
,
4
,
4
}});
auto
v
=
mm
->
add_literal
(
migraphx
::
generate_literal
({
DType
,
{
2
,
8
,
1
,
1
}},
2
));
auto
conv1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
x
,
w
);
auto
conv2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
}},
{
"stride"
,
{
2
,
2
}}}),
y
,
v
);
...
...
@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
return
p
;
}
};
template
struct
test_conv_add_1x1_diff_strides
<
migraphx
::
shape
::
float_type
>;
template
struct
test_conv_add_1x1_diff_strides
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_conv_add_relu.cpp
View file @
6a72e8fc
...
...
@@ -28,18 +28,17 @@
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct
test_conv_add_relu
:
verify_program
<
test_conv_add_relu
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_conv_add_relu
:
verify_program
<
test_conv_add_relu
<
DType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
auto
bias_literal
=
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}},
{
2.0
f
,
2.0
f
,
2.0
f
,
2.0
f
}};
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
DType
,
{
4
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
DType
,
{
4
,
3
,
3
,
3
}});
auto
bias_literal
=
migraphx
::
literal
{
migraphx
::
shape
{
DType
,
{
4
}},
{
2.0
f
,
2.0
f
,
2.0
f
,
2.0
f
}};
auto
bias
=
mm
->
add_literal
(
bias_literal
);
auto
conv
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
input
,
weights
);
auto
bcast_bias
=
mm
->
add_instruction
(
...
...
@@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu>
return
p
;
}
};
template
struct
test_conv_add_relu
<
migraphx
::
shape
::
float_type
>;
template
struct
test_conv_add_relu
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment