Commit 13ef4148 authored by Umang Yadav's avatar Umang Yadav
Browse files

add test for rsqrt and remove old-styple-cast

parent 6155c782
......@@ -25,7 +25,6 @@
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
......
......@@ -23,22 +23,26 @@
*/
#include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_rsqrt : verify_program<test_rsqrt>
template <typename CType>
struct test_rsqrt : verify_program<test_rsqrt<CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
std::vector<size_t> input_lens{1, 3, 16, 16};
migraphx::shape s{migraphx::shape::float_type, input_lens};
migraphx::shape s{dtype, input_lens};
auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.0f);
auto max_val = mm->add_literal(std::numeric_limits<float>::max());
min_val = mm->add_instruction(
auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.0}});
auto max_val = mm->add_literal(
migraphx::literal{migraphx::shape{dtype}, {std::numeric_limits<CType>::max()}});
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
......@@ -48,4 +52,6 @@ struct test_rsqrt : verify_program<test_rsqrt>
};
};
// TOOD : Add FP8 test
template struct test_rsqrt<float>;
template struct test_rsqrt<migraphx::half>;
template struct test_rsqrt<migraphx::fp8::fp8e4m3fnuz>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment