Commit 7d10ba33 authored by umang yadav's avatar umang yadav
Browse files

Merge branch 'develop' into resnet50_partition

parents c8aea733 1d8840b8
...@@ -36,6 +36,7 @@ add_library(migraphx ...@@ -36,6 +36,7 @@ add_library(migraphx
argument.cpp argument.cpp
auto_contiguous.cpp auto_contiguous.cpp
common.cpp common.cpp
common_dims.cpp
compile_src.cpp compile_src.cpp
convert_to_json.cpp convert_to_json.cpp
cpp_generator.cpp cpp_generator.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x > dim;
});
if(x < dim)
return start;
return it;
}
template <class Range>
static auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}
struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims,
std::vector<std::vector<std::size_t>>& paxes_map)
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
{
}
const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1;
std::size_t get() const { return *it / rem; }
bool is_end() const { return it == dims->end(); }
void next(std::size_t i = 1) { it += i; }
auto dims_for(std::size_t d) const
{
auto dim_end = compute_end_dim(it, dims->end(), d);
return range(it, dim_end);
}
void add_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
axes_map->push_back(std::move(axes));
}
void add_multi_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
std::transform(axes.begin(),
axes.end(),
std::back_inserter(*axes_map),
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
}
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{
if(rem != 1)
{
assert(start > 0);
naxes++;
start--;
}
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), start);
return axes;
}
};
static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{
assert(state1.get() <= state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);
if(naxes == 0)
return false;
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return false;
auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return true;
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) > 0);
assert(elements(dims1) == elements(dims2));
common_dims cd;
common_dim_state state1{dims1, cd.axes_map1};
common_dim_state state2{dims2, cd.axes_map2};
while(not state1.is_end() and not state2.is_end())
{
auto d1 = state1.get();
auto d2 = state2.get();
if(d1 <= d2)
{
if(not compute_common_dim(cd.dims, state1, state2))
return {};
}
else // if(d1 > d2)
{
if(not compute_common_dim(cd.dims, state2, state1))
return {};
}
}
assert(elements(dims1) == elements(cd.dims));
return cd;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -24,11 +24,14 @@ ...@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator> #include <iterator>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
...@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m) ...@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
} }
return changed; return changed;
} }
namespace {
struct find_pointwise_reshape_pointwise
{
auto matcher() const
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous = [](auto... ms) {
return match::arg(0)(match::skip(match::name("contiguous")(match::used_once()))(ms...));
};
auto pointwise = match::name("pointwise")(match::used_once());
auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape");
return match::name("pointwise")(match::any_of[match::inputs()](reshape_pointwise));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto cd = common_dims::compute(ins->get_shape().lens(), x_ins->get_shape().lens());
if(cd.dims.empty())
return;
auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input);
return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c);
};
};
auto x_inputs = x_ins->inputs();
std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins));
auto new_x_ins =
m.insert_instruction(x_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs());
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == reshape_ins)
return new_x_ins;
return reshape_input(ins)(input);
});
auto pw = m.insert_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw);
}
};
} // namespace
void fuse_pointwise::apply(module_pass_manager& mpm) const void fuse_pointwise::apply(module_pass_manager& mpm) const
{ {
...@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const ...@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
} }
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(mpm.get_module(), find_pointwise_reshape_pointwise{});
mpm.run_pass(simplify_reshapes{1});
if(not find_pointwise_modules(mpm.get_module())) if(not find_pointwise_modules(mpm.get_module()))
break; break;
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct MIGRAPHX_EXPORT common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
...@@ -623,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins) ...@@ -623,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
template <class... Ms> template <class... Ms>
auto skip(Ms... ms) auto skip(Ms... ms)
{ {
static_assert(((not std::is_convertible<Ms, std::string>{}) and ...),
"Use a matcher not a string for skip.");
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<optional<instruction_ref>>( return fix<optional<instruction_ref>>(
......
...@@ -38,6 +38,7 @@ struct module; ...@@ -38,6 +38,7 @@ struct module;
*/ */
struct MIGRAPHX_EXPORT simplify_reshapes struct MIGRAPHX_EXPORT simplify_reshapes
{ {
size_t depth = 4;
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
...@@ -1446,10 +1446,13 @@ struct find_split_transpose ...@@ -1446,10 +1446,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
...@@ -784,7 +784,7 @@ struct find_transpose_slice ...@@ -784,7 +784,7 @@ struct find_transpose_slice
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 4; i++) for(int i = 0; i < depth; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
......
...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr) ...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr); auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess) if(status != hipSuccess)
return false; return false;
return attr.memoryType == hipMemoryTypeDevice; return attr.type == hipMemoryTypeDevice;
} }
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
......
...@@ -93,6 +93,8 @@ struct mlir_handle ...@@ -93,6 +93,8 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return not(x == y); } friend bool operator!=(ptr x, ptr y) { return not(x == y); }
explicit operator bool() const noexcept { return obj != ptr(); }
T obj{}; T obj{};
}; };
...@@ -867,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx, ...@@ -867,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx,
adjust_param_shapes(m, to_shapes(inputs)); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl; std::cout << m << std::endl;
}
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution); auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs); co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
......
...@@ -177,6 +177,7 @@ add_dependencies(check test_tf) ...@@ -177,6 +177,7 @@ add_dependencies(check test_tf)
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(verify) add_subdirectory(verify)
add_subdirectory(ref)
if(MIGRAPHX_ENABLE_PYTHON) if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py) add_subdirectory(py)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <test.hpp>
using axes_map = std::vector<std::vector<std::size_t>>;
TEST_CASE(common_d1_less)
{
auto cd = migraphx::common_dims::compute({2, 32, 40, 8}, {2, 1280, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2}, {3}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}});
}
TEST_CASE(common1)
{
auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2, 3, 4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}, {4}});
}
TEST_CASE(common2)
{
auto cd = migraphx::common_dims::compute({2, 1280, 8, 8}, {2, 32, 2560});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}, {4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1}, {2, 3, 4}});
}
TEST_CASE(common_error1)
{
auto cd = migraphx::common_dims::compute({6, 35}, {3, 7, 2, 5});
EXPECT(cd.dims.empty());
}
TEST_CASE(common_error2)
{
auto cd = migraphx::common_dims::compute({3, 7, 2, 5}, {6, 35});
EXPECT(cd.dims.empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -21,8 +21,9 @@ ...@@ -21,8 +21,9 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
...@@ -361,4 +362,154 @@ TEST_CASE(no_input) ...@@ -361,4 +362,154 @@ TEST_CASE(no_input)
EXPECT(p == p2); EXPECT(p == p2);
} }
TEST_CASE(add_reshape_add)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 10, 16}};
migraphx::shape s2{migraphx::shape::float_type, {3, 40, 2, 2}};
migraphx::shape s3{migraphx::shape::float_type, {3, 10, 4, 2, 2}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto reshape =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y);
auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
auto reshape =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd);
mm->add_return({reshape});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(add_reshape_add_nonstandard)
{
migraphx::shape s1 =
migraphx::shape::from_permutation(migraphx::shape::float_type, {3, 10, 16}, {2, 0, 1});
migraphx::shape s2{migraphx::shape::float_type, {3, 40, 2, 2}};
migraphx::shape s3{migraphx::shape::float_type, {3, 10, 4, 2, 2}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), add1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), c);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cy);
auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
auto reshape =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd);
mm->add_return({reshape});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(add_unsqueeze_add_nonstandard)
{
migraphx::shape s1 =
migraphx::shape::from_permutation(migraphx::shape::float_type, {3, 10, 16}, {2, 0, 1});
migraphx::shape s2{migraphx::shape::float_type, {3, 10, 1, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), unsqueeze, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cy);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(add_reshape_add_error)
{
migraphx::shape s1{migraphx::shape::float_type, {6, 35}};
migraphx::shape s2{migraphx::shape::float_type, {3, 7, 2, 5}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto reshape =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto fadd1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto reshape =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd1);
auto fadd2 = add_pointwise(p2, "main:pointwise1", {reshape, z}, single_pointwise("add"));
mm->add_return({fadd2});
}
EXPECT(p1.sort() == p2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
file(GLOB REF CONFIGURE_DEPENDS *.cpp)
add_test_executable(test_ref ${REF})
target_include_directories(test_ref PUBLIC ../include)
rocm_clang_tidy_check(test_ref)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(abs_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 2, -3, 4}});
mm->add_instruction(migraphx::make_op("abs"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(abs_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 8}, {2, 2}}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("abs"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> a = {-1, 2, -3, 4};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}};
params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(acos_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
std::vector<float> data{-0.8f, 0.0f, 1.0f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("acos"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(acos_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("acos"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-0.8f, 0.0f, 1.0f};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(acosh_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
std::vector<float> data{1.1f, 1.2f, 2.0f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("acosh"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(acosh_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.1f, 1.2f, 2.0f};
mm->add_instruction(migraphx::make_op("acosh"), input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(add_broadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraphx::shape b_shape{migraphx::shape::float_type, {2, 2}};
std::vector<float> b_data{0, -1, -2, -3};
uint64_t axis = 0;
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l1, l3);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(add_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1);
auto l4 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(add_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
mm->add_instruction(migraphx::make_op("add"), l1, l2);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(add_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("add"), x, y);
p.compile(migraphx::make_target("ref"));
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(fp16_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {1}};
migraphx::half a{1.5};
migraphx::half b{2.5};
migraphx::half c{4.0};
auto l0 = mm->add_literal(migraphx::literal{s, {a}});
auto l1 = mm->add_literal(migraphx::literal{s, {b}});
mm->add_instruction(migraphx::make_op("add"), l0, l1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(fp32_fp16_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = mm->add_literal(migraphx::literal(s, data));
mm->add_instruction(migraphx::make_op("add"), l1, l2);
return p;
};
auto test_case = [&](std::vector<std::string>&& op_names) {
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program();
migraphx::quantize_fp16(p, op_names);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res, gold_res));
};
test_case({"all"});
test_case({"add"});
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(argmax_test_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 1}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_neg_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -2}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 6}, {3, 6}}};
auto dl = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl);
p.compile(migraphx::make_target("ref"));
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 3, 4}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(argmin_test_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", 0}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_neg_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment