Commit 6f768035 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'rocblas_mlir_fp8' into miopen_fp8

parents da7717ce b2542239
...@@ -97,7 +97,8 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DT ...@@ -97,7 +97,8 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DT
{{"mode", migraphx::op::pooling_mode::average}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}}, {"padding", {1, 1}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {3, 3}}}), {"lengths", {3, 3}},
{"dilations", {1, 1}}}),
relu); relu);
return p; return p;
} }
......
...@@ -46,4 +46,5 @@ struct test_conv_group_add : verify_program<test_conv_group_add<DType>> ...@@ -46,4 +46,5 @@ struct test_conv_group_add : verify_program<test_conv_group_add<DType>>
} }
}; };
template struct test_conv_group_add<migraphx::shape::float_type>; template struct test_conv_group_add<migraphx::shape::float_type>;
// grouped convolutions are not supported with MLIR therefore disable it
// template struct test_conv_group_add<migraphx::shape::fp8e4m3fnuz_type>; // template struct test_conv_group_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>> ...@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction( auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
......
...@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>> ...@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv); mm->add_instruction(migraphx::make_op("relu"), conv);
...@@ -42,4 +42,5 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>> ...@@ -42,4 +42,5 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
} }
}; };
template struct test_conv_relu<migraphx::shape::float_type>; template struct test_conv_relu<migraphx::shape::float_type>;
template struct test_conv_relu<migraphx::shape::half_type>;
template struct test_conv_relu<migraphx::shape::fp8e4m3fnuz_type>; template struct test_conv_relu<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -34,41 +34,52 @@ ...@@ -34,41 +34,52 @@
Adding this because HIP fmod sign changes when y = 0 resulting in nan and -nan not beign Adding this because HIP fmod sign changes when y = 0 resulting in nan and -nan not beign
consistent between ref and gpu implementations. consistent between ref and gpu implementations.
*/ */
migraphx::instruction_ref add_epsilon(migraphx::module& m, migraphx::instruction_ref y) migraphx::instruction_ref add_epsilon(migraphx::module& m,
migraphx::instruction_ref y,
migraphx::shape::type_t dtype = migraphx::shape::float_type)
{ {
auto zero = m.add_literal(0.0f); auto zero = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {0.0f}});
auto eps = m.add_literal(1e-3f); auto eps = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {1e-3f}});
auto op_y = add_common_op(m, migraphx::make_op("equal"), {y, zero}); auto op_y = add_common_op(m, migraphx::make_op("equal"), {y, zero});
return add_common_op(m, migraphx::make_op("where"), {op_y, eps, y}); return add_common_op(m, migraphx::make_op("where"), {op_y, eps, y});
} }
struct test_fmod : verify_program<test_fmod> template <migraphx::shape::type_t DType>
struct test_fmod : verify_program<test_fmod<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}}; migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y); auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("fmod"), x, op_where); mm->add_instruction(migraphx::make_op("fmod"), x, op_where);
return p; return p;
} }
}; };
template struct test_fmod<migraphx::shape::float_type>;
template struct test_fmod<migraphx::shape::half_type>;
template struct test_fmod<migraphx::shape::fp8e4m3fnuz_type>;
struct test_mod : verify_program<test_mod> template <migraphx::shape::type_t DType>
struct test_mod : verify_program<test_mod<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}}; migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y); auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("mod"), x, op_where); mm->add_instruction(migraphx::make_op("mod"), x, op_where);
return p; return p;
} }
}; };
// TODO: check if requires FP8 test
template struct test_mod<migraphx::shape::float_type>;
// TODO: Fix half type test
// template struct test_mod<migraphx::shape::half_type>;
template struct test_mod<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -41,4 +41,5 @@ struct test_gemm : verify_program<test_gemm<DType>> ...@@ -41,4 +41,5 @@ struct test_gemm : verify_program<test_gemm<DType>>
}; };
template struct test_gemm<migraphx::shape::float_type>; template struct test_gemm<migraphx::shape::float_type>;
template struct test_gemm<migraphx::shape::half_type>;
template struct test_gemm<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -49,4 +49,5 @@ struct test_gemm_copy : verify_program<test_gemm_copy<DType>> ...@@ -49,4 +49,5 @@ struct test_gemm_copy : verify_program<test_gemm_copy<DType>>
}; };
template struct test_gemm_copy<migraphx::shape::float_type>; template struct test_gemm_copy<migraphx::shape::float_type>;
template struct test_gemm_copy<migraphx::shape::half_type>;
template struct test_gemm_copy<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_copy<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -41,4 +41,5 @@ struct test_gemm_ex : verify_program<test_gemm_ex<DType>> ...@@ -41,4 +41,5 @@ struct test_gemm_ex : verify_program<test_gemm_ex<DType>>
} }
}; };
template struct test_gemm_ex<migraphx::shape::float_type>; template struct test_gemm_ex<migraphx::shape::float_type>;
template struct test_gemm_ex<migraphx::shape::half_type>;
template struct test_gemm_ex<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -43,4 +43,5 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea<DType>> ...@@ -43,4 +43,5 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea<DType>>
}; };
template struct test_gemm_transposea<migraphx::shape::float_type>; template struct test_gemm_transposea<migraphx::shape::float_type>;
template struct test_gemm_transposea<migraphx::shape::half_type>;
template struct test_gemm_transposea<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_transposea<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -44,4 +44,5 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>> ...@@ -44,4 +44,5 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>>
}; };
template struct test_gemm_transposea_ex<migraphx::shape::float_type>; template struct test_gemm_transposea_ex<migraphx::shape::float_type>;
template struct test_gemm_transposea_ex<migraphx::shape::half_type>;
template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -44,4 +44,5 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>> ...@@ -44,4 +44,5 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>>
}; };
template struct test_gemm_transposeab<migraphx::shape::float_type>; template struct test_gemm_transposeab<migraphx::shape::float_type>;
template struct test_gemm_transposeab<migraphx::shape::half_type>;
template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -43,4 +43,5 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb<DType>> ...@@ -43,4 +43,5 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb<DType>>
}; };
template struct test_gemm_transposeb<migraphx::shape::float_type>; template struct test_gemm_transposeb<migraphx::shape::float_type>;
template struct test_gemm_transposeb<migraphx::shape::half_type>;
template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> template <migraphx::shape::type_t DType>
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>> struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
{ {
...@@ -43,4 +44,5 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>> ...@@ -43,4 +44,5 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
}; };
template struct test_gemm_transposeb_ex<migraphx::shape::float_type>; template struct test_gemm_transposeb_ex<migraphx::shape::float_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::half_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fnuz_type>; template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -36,7 +36,7 @@ struct test_max_pooling_ceil_3d : verify_program<test_max_pooling_ceil_3d> ...@@ -36,7 +36,7 @@ struct test_max_pooling_ceil_3d : verify_program<test_max_pooling_ceil_3d>
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{ auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true}; migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, {1, 1, 1}, true};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -49,4 +49,5 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>> ...@@ -49,4 +49,5 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
}; };
template struct test_mul_dot_a<migraphx::shape::float_type>; template struct test_mul_dot_a<migraphx::shape::float_type>;
template struct test_mul_dot_a<migraphx::shape::half_type>;
template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>; template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,8 +28,8 @@ ...@@ -28,8 +28,8 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> template <migraphx::shape::type_t DType>
struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>> struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -50,4 +50,5 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>> ...@@ -50,4 +50,5 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
}; };
template struct test_mul_dot_b<migraphx::shape::float_type>; template struct test_mul_dot_b<migraphx::shape::float_type>;
template struct test_mul_dot_b<migraphx::shape::half_type>;
template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>; template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -22,11 +22,11 @@ ...@@ -22,11 +22,11 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/float8.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/float8.hpp>
template <class T> template <class T>
struct test_nearbyint : verify_program<test_nearbyint<T>> struct test_nearbyint : verify_program<test_nearbyint<T>>
......
...@@ -22,12 +22,12 @@ ...@@ -22,12 +22,12 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
template <migraphx::shape::type_t DType> template <migraphx::shape::type_t DType>
struct test_reduce_add : verify_program<test_reduce_add<DType>> struct test_reduce_add : verify_program<test_reduce_add<DType>>
......
...@@ -21,11 +21,11 @@ ...@@ -21,11 +21,11 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>
template <migraphx::shape::type_t DType> template <migraphx::shape::type_t DType>
struct test_scatternd : verify_program<test_scatternd<DType>> struct test_scatternd : verify_program<test_scatternd<DType>>
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,24 +21,31 @@ ...@@ -21,24 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_relu_half : verify_program<test_conv_relu_half> struct test_scatternd_max : verify_program<test_scatternd_max>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto dtype = migraphx::shape::float_type;
mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); auto itype = migraphx::shape::int64_type;
auto weights = migraphx::shape ds{dtype, {8}};
mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); migraphx::shape is{itype, {4, 1}};
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); migraphx::shape us{dtype, {4}};
mm->add_instruction(migraphx::make_op("relu"), conv); std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment