Commit 538dbd75 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into resize_op

parents c7161d99 e3e00547
...@@ -27,13 +27,15 @@ ...@@ -27,13 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_pow : verify_program<test_pow> template <typename CType>
struct test_pow : verify_program<test_pow<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{migraphx::shape::float_type, {6}}; auto* mm = p.get_main_module();
migraphx::shape s{dtype, {6}};
std::vector<float> vec_e(s.elements(), 2.0f); std::vector<float> vec_e(s.elements(), 2.0f);
auto b = mm->add_parameter("x", s); auto b = mm->add_parameter("x", s);
auto e = mm->add_literal(migraphx::literal(s, vec_e)); auto e = mm->add_literal(migraphx::literal(s, vec_e));
...@@ -41,3 +43,6 @@ struct test_pow : verify_program<test_pow> ...@@ -41,3 +43,6 @@ struct test_pow : verify_program<test_pow>
return p; return p;
} }
}; };
template struct test_pow<float>;
template struct test_pow<migraphx::half>;
template struct test_pow<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,14 +27,16 @@ ...@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
struct test_reduce_add : verify_program<test_reduce_add> template <migraphx::shape::type_t DType>
struct test_reduce_add : verify_program<test_reduce_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 1000, 2, 2}}; migraphx::shape s{DType, {4, 1000, 2, 2}};
migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}}; migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce_mean = auto reduce_mean =
...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add> ...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add>
return p; return p;
}; };
}; };
template struct test_reduce_add<migraphx::shape::float_type>;
template struct test_reduce_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,14 +28,14 @@ ...@@ -28,14 +28,14 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> template <migraphx::shape::type_t DType>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation( auto s = migraphx::shape::from_permutation(DType, {4, 256, 2, 2}, {0, 2, 3, 1});
migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x); auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce); auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce);
...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> ...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc>
return p; return p;
}; };
}; };
template struct test_reduce_mean_nhwc<migraphx::shape::float_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::half_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap ...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_max,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_mean,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_min,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_prod,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum,
1,
migraphx::shape::fp8e4m3fnuz_type>;
struct test_reduce_mean_1 : verify_program<test_reduce_mean_1> struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -56,3 +56,19 @@ template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::sha ...@@ -56,3 +56,19 @@ template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::sha
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod,
-2,
migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_roialign : verify_program<test_roialign> template <migraphx::shape::type_t DType>
struct test_roialign : verify_program<test_roialign<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; migraphx::shape x_s{DType, {5, 4, 10, 10}};
migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; migraphx::shape roi_s{DType, {5, 4}};
migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; migraphx::shape ind_s{migraphx::shape::int64_type, {5}};
std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1}; std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1};
...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign> ...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign>
auto roi = mm->add_parameter("roi", roi_s); auto roi = mm->add_parameter("roi", roi_s);
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r = mm->add_instruction(migraphx::make_op("roialign", auto r = mm->add_instruction(migraphx::make_op("roialign",
{{"spatial_scale", 1.0}, {{"spatial_scale", 1.0},
{"output_height", 5}, {"output_height", 5},
{"output_width", 5}, {"output_width", 5},
{"sampling_ratio", 2}}), {"sampling_ratio", 2}}),
x, x,
roi, roi,
ind); ind);
...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign> ...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign>
return p; return p;
} }
}; };
template struct test_roialign<migraphx::shape::float_type>;
template struct test_roialign<migraphx::shape::half_type>;
template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -23,22 +23,26 @@ ...@@ -23,22 +23,26 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_rsqrt : verify_program<test_rsqrt> template <typename CType>
struct test_rsqrt : verify_program<test_rsqrt<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
std::vector<size_t> input_lens{1, 3, 16, 16}; std::vector<size_t> input_lens{1, 3, 16, 16};
migraphx::shape s{migraphx::shape::float_type, input_lens}; migraphx::shape s{dtype, input_lens};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.0f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.0}});
auto max_val = mm->add_literal(std::numeric_limits<float>::max()); auto max_val = mm->add_literal(
min_val = mm->add_instruction( migraphx::literal{migraphx::shape{dtype}, {std::numeric_limits<CType>::max()}});
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
...@@ -47,3 +51,7 @@ struct test_rsqrt : verify_program<test_rsqrt> ...@@ -47,3 +51,7 @@ struct test_rsqrt : verify_program<test_rsqrt>
return p; return p;
}; };
}; };
template struct test_rsqrt<float>;
template struct test_rsqrt<migraphx::half>;
template struct test_rsqrt<migraphx::fp8::fp8e4m3fnuz>;
...@@ -25,18 +25,19 @@ ...@@ -25,18 +25,19 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>
struct test_scatternd : verify_program<test_scatternd> template <migraphx::shape::type_t DType>
struct test_scatternd : verify_program<test_scatternd<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type; auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {1}}; migraphx::shape ds{DType, {1}};
migraphx::shape is{itype, {4, 1}}; migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}}; migraphx::shape us{DType, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7}; std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto ld = mm->add_literal(migraphx::literal{ds, {1}}); auto ld = mm->add_literal(migraphx::literal{ds, {1}});
...@@ -51,3 +52,7 @@ struct test_scatternd : verify_program<test_scatternd> ...@@ -51,3 +52,7 @@ struct test_scatternd : verify_program<test_scatternd>
return p; return p;
} }
}; };
template struct test_scatternd<migraphx::shape::float_type>;
template struct test_scatternd<migraphx::shape::half_type>;
template struct test_scatternd<migraphx::shape::fp8e4m3fnuz_type>;
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,23 +21,31 @@ ...@@ -21,23 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <limits>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/half.hpp>
struct test_isnan_half : verify_program<test_isnan_half> struct test_scatternd_max : verify_program<test_scatternd_max>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2}}); auto dtype = migraphx::shape::float_type;
auto l0 = mm->add_literal(std::numeric_limits<migraphx::half>::quiet_NaN()); auto itype = migraphx::shape::int64_type;
x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); migraphx::shape ds{dtype, {8}};
mm->add_instruction(migraphx::make_op("isnan"), x); migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
return p; return p;
} }
}; };
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,42 +21,31 @@ ...@@ -21,42 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #include "verify_program.hpp"
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/argument.hpp> struct test_scatternd_max_duplicate_idx : verify_program<test_scatternd_max_duplicate_idx>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_gather
{ {
op::gather op; migraphx::program create_program() const
template <class Self, class F>
static auto reflect(Self& self, F f)
{ {
return migraphx::reflect(self.op, f); migraphx::program p;
} auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
std::string name() const { return "gpu::gather"; } auto itype = migraphx::shape::int64_type;
shape compute_shape(std::vector<shape> inputs) const; migraphx::shape ds{dtype, {8}};
argument migraphx::shape is{itype, {4, 1}};
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; migraphx::shape us{dtype, {4}};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::vector<int64_t> ind_vec{4, 7, 4, 7};
{
return shapes.size() - 1; auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
return p;
} }
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,41 +21,31 @@ ...@@ -21,41 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP #include "verify_program.hpp"
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP #include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/argument.hpp> struct test_scatternd_min : verify_program<test_scatternd_min>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_pad
{ {
op::pad op; migraphx::program create_program() const
template <class Self, class F>
static auto reflect(Self& self, F f)
{ {
return migraphx::reflect(self.op, f); migraphx::program p;
} auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
std::string name() const { return "gpu::pad"; } auto itype = migraphx::shape::int64_type;
shape compute_shape(std::vector<shape> inputs) const; migraphx::shape ds{dtype, {8}};
argument migraphx::shape is{itype, {4, 1}};
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; migraphx::shape us{dtype, {4}};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::vector<int64_t> ind_vec{4, 3, 1, 7};
{
return shapes.size() - 1; auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
return p;
} }
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatternd_min_duplicate_idx : verify_program<test_scatternd_min_duplicate_idx>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 7, 4, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
return p;
}
};
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sin : verify_program<test_sin> template <migraphx::shape::type_t DType>
struct test_sin : verify_program<test_sin<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {10}}; migraphx::shape s{DType, {10}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("sin"), x); mm->add_instruction(migraphx::make_op("sin"), x);
return p; return p;
} }
}; };
template struct test_sin<migraphx::shape::float_type>;
template struct test_sin<migraphx::shape::half_type>;
template struct test_sin<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sinh : verify_program<test_sinh> template <migraphx::shape::type_t DType>
struct test_sinh : verify_program<test_sinh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("sinh"), x); mm->add_instruction(migraphx::make_op("sinh"), x);
return p; return p;
} }
}; };
template struct test_sinh<migraphx::shape::float_type>;
template struct test_sinh<migraphx::shape::half_type>;
template struct test_sinh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; ...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>;
template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sqrt : verify_program<test_sqrt> template <migraphx::shape::type_t DType>
struct test_sqrt : verify_program<test_sqrt<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs); mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p; return p;
} }
}; };
template struct test_sqrt<migraphx::shape::float_type>;
template struct test_sqrt<migraphx::shape::half_type>;
template struct test_sqrt<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tan : verify_program<test_tan> template <migraphx::shape::type_t DType>
struct test_tan : verify_program<test_tan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("tan"), x); mm->add_instruction(migraphx::make_op("tan"), x);
return p; return p;
} }
}; };
template struct test_tan<migraphx::shape::float_type>;
template struct test_tan<migraphx::shape::half_type>;
template struct test_tan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,19 @@ ...@@ -27,14 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tanh : verify_program<test_tanh> template <migraphx::shape::type_t DType>
struct test_tanh : verify_program<test_tanh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("tanh"), x); mm->add_instruction(migraphx::make_op("tanh"), x);
return p; return p;
} }
}; };
template struct test_tanh<migraphx::shape::float_type>;
template struct test_tanh<migraphx::shape::half_type>;
template struct test_tanh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where> template <migraphx::shape::type_t DType>
struct test_where : verify_program<test_where<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where> ...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where>
return p; return p;
}; };
}; };
template struct test_where<migraphx::shape::float_type>;
template struct test_where<migraphx::shape::half_type>;
template struct test_where<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -134,10 +134,11 @@ def check_correctness(gold_outputs, ...@@ -134,10 +134,11 @@ def check_correctness(gold_outputs,
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False ret = False
if verbose: if verbose:
print('\nOutput {} is incorrect ...'.format(i)) with np.printoptions(threshold=np.inf):
print('Expected value: \n{}'.format(gold_outputs[i])) print('\nOutput {} is incorrect ...'.format(i))
print('......') print('Expected value: \n{}\n'.format(gold_outputs[i]))
print('Actual value: \n{}\n'.format(outputs[i])) print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i]))
else: else:
print('Outputs do not match') print('Outputs do not match')
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment