Commit 13ef4148 authored by Umang Yadav's avatar Umang Yadav
Browse files

add test for rsqrt and remove old-styple-cast

parent 6155c782
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wc++20-extensions" #pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__ #endif // __clang__
......
...@@ -23,22 +23,26 @@ ...@@ -23,22 +23,26 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_rsqrt : verify_program<test_rsqrt> template <typename CType>
struct test_rsqrt : verify_program<test_rsqrt<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
std::vector<size_t> input_lens{1, 3, 16, 16}; std::vector<size_t> input_lens{1, 3, 16, 16};
migraphx::shape s{migraphx::shape::float_type, input_lens}; migraphx::shape s{dtype, input_lens};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.0f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.0}});
auto max_val = mm->add_literal(std::numeric_limits<float>::max()); auto max_val = mm->add_literal(
min_val = mm->add_instruction( migraphx::literal{migraphx::shape{dtype}, {std::numeric_limits<CType>::max()}});
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
...@@ -48,4 +52,6 @@ struct test_rsqrt : verify_program<test_rsqrt> ...@@ -48,4 +52,6 @@ struct test_rsqrt : verify_program<test_rsqrt>
}; };
}; };
// TOOD : Add FP8 test template struct test_rsqrt<float>;
template struct test_rsqrt<migraphx::half>;
template struct test_rsqrt<migraphx::fp8::fp8e4m3fnuz>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment