Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13ef4148
Commit
13ef4148
authored
Nov 20, 2023
by
Umang Yadav
Browse files
add test for rsqrt and remove old-styple-cast
parent
6155c782
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
8 deletions
+13
-8
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+0
-1
test/verify/test_rsqrt.cpp
test/verify/test_rsqrt.cpp
+13
-7
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
13ef4148
...
...
@@ -25,7 +25,6 @@
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
...
...
test/verify/test_rsqrt.cpp
View file @
13ef4148
...
...
@@ -23,22 +23,26 @@
*/
#include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_rsqrt
:
verify_program
<
test_rsqrt
>
template
<
typename
CType
>
struct
test_rsqrt
:
verify_program
<
test_rsqrt
<
CType
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
type_t
dtype
=
migraphx
::
shape
::
get_type
<
CType
>
();
std
::
vector
<
size_t
>
input_lens
{
1
,
3
,
16
,
16
};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_
type
,
input_lens
};
migraphx
::
shape
s
{
d
type
,
input_lens
};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
min_val
=
mm
->
add_literal
(
1.0
f
);
auto
max_val
=
mm
->
add_literal
(
std
::
numeric_limits
<
float
>::
max
());
min_val
=
mm
->
add_instruction
(
auto
min_val
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
dtype
},
{
1.0
}});
auto
max_val
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
dtype
},
{
std
::
numeric_limits
<
CType
>::
max
()}});
min_val
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
min_val
);
max_val
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
max_val
);
...
...
@@ -48,4 +52,6 @@ struct test_rsqrt : verify_program<test_rsqrt>
};
};
// TOOD : Add FP8 test
template
struct
test_rsqrt
<
float
>;
template
struct
test_rsqrt
<
migraphx
::
half
>;
template
struct
test_rsqrt
<
migraphx
::
fp8
::
fp8e4m3fnuz
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment