Commit 6f768035 authored by Umang Yadav

Merge branch 'rocblas_mlir_fp8' into miopen_fp8

parents da7717ce b2542239
......@@ -97,7 +97,8 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DT
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 3}}}),
{"lengths", {3, 3}},
{"dilations", {1, 1}}}),
relu);
return p;
}
......
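Note: the hunk above adds the new "dilations" attribute to the pooling operator alongside padding, stride, and lengths. Below is a minimal sketch of building such a pooling instruction on its own, reusing the same make_op attribute names from the hunk; the parameter shape and the op/pooling.hpp include are illustrative assumptions, not taken from this file.

    #include <migraphx/program.hpp>
    #include <migraphx/make_op.hpp>
    #include <migraphx/op/pooling.hpp> // assumed header for migraphx::op::pooling_mode

    // Sketch: average pooling that also sets the newly added "dilations" attribute.
    inline migraphx::program make_dilated_pooling()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
        mm->add_instruction(migraphx::make_op("pooling",
                                              {{"mode", migraphx::op::pooling_mode::average},
                                               {"padding", {1, 1}},
                                               {"stride", {2, 2}},
                                               {"lengths", {3, 3}},
                                               {"dilations", {1, 1}}}),
                            x);
        return p;
    }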
......@@ -46,4 +46,5 @@ struct test_conv_group_add : verify_program<test_conv_group_add<DType>>
}
};
template struct test_conv_group_add<migraphx::shape::float_type>;
// grouped convolutions are not supported with MLIR, therefore this instantiation is disabled
// template struct test_conv_group_add<migraphx::shape::fp8e4m3fnuz_type>;
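For context on the disabled instantiation: a grouped convolution in MIGraphX is expressed through the convolution operator's "group" attribute. A hedged sketch follows; the shapes are illustrative and not the ones test_conv_group_add actually uses.

    #include <migraphx/program.hpp>
    #include <migraphx/make_op.hpp>

    // Sketch: a 4-group convolution; with 8 input channels each group sees 2 of them,
    // so the weight shape is {out_channels, in_channels / group, kH, kW}.
    inline migraphx::program make_grouped_conv()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {1, 8, 16, 16}});
        auto w   = mm->add_parameter(
            "w", migraphx::shape{migraphx::shape::float_type, {8, 2, 3, 3}});
        mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w);
        return p;
    }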
......@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
......
......@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
......@@ -42,4 +42,5 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
}
};
template struct test_conv_relu<migraphx::shape::float_type>;
template struct test_conv_relu<migraphx::shape::half_type>;
template struct test_conv_relu<migraphx::shape::fp8e4m3fnuz_type>;
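Most of the remaining hunks in this merge follow the pattern shown here: the verify test is templated on migraphx::shape::type_t, and FP8 coverage is enabled with one extra explicit instantiation. For reference, here is test_conv_relu assembled from the fragments above into a single condensed sketch; the include list is taken from the other test files in this diff and may differ from the actual file.

    #include "verify_program.hpp"
    #include <migraphx/program.hpp>
    #include <migraphx/generate.hpp>
    #include <migraphx/make_op.hpp>

    template <migraphx::shape::type_t DType>
    struct test_conv_relu : verify_program<test_conv_relu<DType>>
    {
        migraphx::program create_program() const
        {
            migraphx::program p;
            auto* mm     = p.get_main_module();
            auto input   = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
            auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
            auto conv    = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
            mm->add_instruction(migraphx::make_op("relu"), conv);
            return p;
        }
    };

    // One explicit instantiation per element type to be verified; the FP8 line is the
    // one added by this merge.
    template struct test_conv_relu<migraphx::shape::float_type>;
    template struct test_conv_relu<migraphx::shape::half_type>;
    template struct test_conv_relu<migraphx::shape::fp8e4m3fnuz_type>;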
......@@ -34,41 +34,52 @@
Adding this because the HIP fmod sign changes when y = 0, resulting in nan and -nan not being
consistent between the ref and gpu implementations.
*/
migraphx::instruction_ref add_epsilon(migraphx::module& m, migraphx::instruction_ref y)
migraphx::instruction_ref add_epsilon(migraphx::module& m,
migraphx::instruction_ref y,
migraphx::shape::type_t dtype = migraphx::shape::float_type)
{
auto zero = m.add_literal(0.0f);
auto eps = m.add_literal(1e-3f);
auto zero = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {0.0f}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {1e-3f}});
auto op_y = add_common_op(m, migraphx::make_op("equal"), {y, zero});
return add_common_op(m, migraphx::make_op("where"), {op_y, eps, y});
}
struct test_fmod : verify_program<test_fmod>
template <migraphx::shape::type_t DType>
struct test_fmod : verify_program<test_fmod<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}};
migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y);
auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("fmod"), x, op_where);
return p;
}
};
template struct test_fmod<migraphx::shape::float_type>;
template struct test_fmod<migraphx::shape::half_type>;
template struct test_fmod<migraphx::shape::fp8e4m3fnuz_type>;
struct test_mod : verify_program<test_mod>
template <migraphx::shape::type_t DType>
struct test_mod : verify_program<test_mod<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}};
migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y);
auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("mod"), x, op_where);
return p;
}
};
// TODO: check if an FP8 test is required
template struct test_mod<migraphx::shape::float_type>;
// TODO: Fix half type test
// template struct test_mod<migraphx::shape::half_type>;
template struct test_mod<migraphx::shape::fp8e4m3fnuz_type>;
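Why add_epsilon is needed, in plain C++ terms: fmod with a zero divisor returns NaN, and the sign of that NaN is not pinned down, so the ref and gpu results can disagree on nan versus -nan. A minimal standalone sketch of the substitution these tests perform; std::fmod here is only an illustration, not the MIGraphX operator.

    #include <cmath>
    #include <cstdio>

    int main()
    {
        float x = 5.0f, y = 0.0f, eps = 1e-3f;
        // Same effect as the equal + where combination inside add_epsilon above.
        float y_safe = (y == 0.0f) ? eps : y;
        std::printf("fmod(x, y)      = %f\n", std::fmod(x, y));      // nan, sign unspecified
        std::printf("fmod(x, y_safe) = %f\n", std::fmod(x, y_safe)); // well-defined value
        return 0;
    }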
......@@ -41,4 +41,5 @@ struct test_gemm : verify_program<test_gemm<DType>>
};
template struct test_gemm<migraphx::shape::float_type>;
template struct test_gemm<migraphx::shape::half_type>;
template struct test_gemm<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -49,4 +49,5 @@ struct test_gemm_copy : verify_program<test_gemm_copy<DType>>
};
template struct test_gemm_copy<migraphx::shape::float_type>;
template struct test_gemm_copy<migraphx::shape::half_type>;
template struct test_gemm_copy<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -41,4 +41,5 @@ struct test_gemm_ex : verify_program<test_gemm_ex<DType>>
}
};
template struct test_gemm_ex<migraphx::shape::float_type>;
template struct test_gemm_ex<migraphx::shape::half_type>;
template struct test_gemm_ex<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -43,4 +43,5 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea<DType>>
};
template struct test_gemm_transposea<migraphx::shape::float_type>;
template struct test_gemm_transposea<migraphx::shape::half_type>;
template struct test_gemm_transposea<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -44,4 +44,5 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>>
};
template struct test_gemm_transposea_ex<migraphx::shape::float_type>;
template struct test_gemm_transposea_ex<migraphx::shape::half_type>;
template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -44,4 +44,5 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>>
};
template struct test_gemm_transposeab<migraphx::shape::float_type>;
template struct test_gemm_transposeab<migraphx::shape::half_type>;
template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -43,4 +43,5 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb<DType>>
};
template struct test_gemm_transposeb<migraphx::shape::float_type>;
template struct test_gemm_transposeb<migraphx::shape::half_type>;
template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -26,6 +26,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType>
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
{
......@@ -43,4 +44,5 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
};
template struct test_gemm_transposeb_ex<migraphx::shape::float_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::half_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -36,7 +36,7 @@ struct test_max_pooling_ceil_3d : verify_program<test_max_pooling_ceil_3d>
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, {1, 1, 1}, true};
mm->add_instruction(op, input);
return p;
}
......
......@@ -49,4 +49,5 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
};
template struct test_mul_dot_a<migraphx::shape::float_type>;
template struct test_mul_dot_a<migraphx::shape::half_type>;
template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -28,8 +28,8 @@
#include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType>
struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
{
migraphx::program create_program() const
{
......@@ -50,4 +50,5 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
};
template struct test_mul_dot_b<migraphx::shape::float_type>;
template struct test_mul_dot_b<migraphx::shape::half_type>;
template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -22,11 +22,11 @@
* THE SOFTWARE.
*/
#include "migraphx/float8.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/float8.hpp>
template <class T>
struct test_nearbyint : verify_program<test_nearbyint<T>>
......
......@@ -22,12 +22,12 @@
* THE SOFTWARE.
*/
#include "migraphx/shape.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
template <migraphx::shape::type_t DType>
struct test_reduce_add : verify_program<test_reduce_add<DType>>
......
......@@ -21,11 +21,11 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/shape.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>
template <migraphx::shape::type_t DType>
struct test_scatternd : verify_program<test_scatternd<DType>>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,24 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_relu_half : verify_program<test_conv_relu_half>
struct test_scatternd_max : verify_program<test_scatternd_max>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
return p;
}
};
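A small worked example of what this test exercises, assuming scatternd_max follows ONNX ScatterND with reduction="max": the output starts as a copy of data, and each indexed element is replaced by the max of its current value and the corresponding update. The data and update values below are illustrative; only the indices come from the test.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> data      = {1, 2, 3, 4, 5, 6, 7, 8}; // example "data" values
        std::vector<int64_t> indices = {4, 3, 1, 7};             // same indices as the test
        std::vector<float> updates   = {9, 0, 10, 2};            // example "update" values
        std::vector<float> out       = data;
        for(std::size_t i = 0; i < indices.size(); ++i)
            out[indices[i]] = std::max(out[indices[i]], updates[i]);
        for(float v : out)
            std::printf("%g ", v); // prints: 1 10 3 4 9 6 7 8
        std::printf("\n");
        return 0;
    }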