Commit f7079e51 authored by Paul's avatar Paul
Browse files

Merge

parents 79eac1b8 f6e22d56
...@@ -195,6 +195,14 @@ constexpr auto compose(Fs... fs) ...@@ -195,6 +195,14 @@ constexpr auto compose(Fs... fs)
})(fs...); })(fs...);
} }
// Curry-style partial application: partial(f)(xs...) returns a callable that,
// when invoked with the remaining arguments ys..., calls f(xs..., ys...).
// The leading arguments are captured by value; the trailing arguments are
// forwarded with their original value category.
template <class F>
constexpr auto partial(F f)
{
    return [=](auto... xs) {
        // Bind the leading argument pack now; invoke f on the final call.
        auto bound = [=](auto&&... ys) {
            return f(xs..., static_cast<decltype(ys)>(ys)...);
        };
        return bound;
    };
}
template <class... Ts> template <class... Ts>
constexpr auto pack(Ts... xs) constexpr auto pack(Ts... xs)
{ {
......
...@@ -233,6 +233,12 @@ struct index ...@@ -233,6 +233,12 @@ struct index
} }
}; };
// MIGRAPHX_GLOBAL marks kernel entry points. When the block size is known at
// compile time (MIGRAPHX_NLOCAL is defined), attach the amdgpu flat work-group
// size attribute with min == max == MIGRAPHX_NLOCAL so the compiler can tune
// register allocation for that exact group size; otherwise fall back to a
// plain __global__ qualifier.
#ifdef MIGRAPHX_NLOCAL
#define MIGRAPHX_GLOBAL \
__global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
#else
#define MIGRAPHX_GLOBAL __global__
#endif
inline __device__ __attribute__((const)) index make_index() inline __device__ __attribute__((const)) index make_index()
{ {
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
......
...@@ -174,6 +174,25 @@ struct inner_storage_tag ...@@ -174,6 +174,25 @@ struct inner_storage_tag
template <class T> template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>; using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
// Inner storage that computes elements on demand instead of holding a
// materialized buffer: element (j, d) is produced by invoking the stored
// generator f(j, d).
template <class Size, class F>
struct lazy_inner_storage : inner_storage_tag
{
// Element type the generator yields (reference stripped).
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
// Reduction size: returns a value-initialized Size, so the size information
// is carried entirely by the Size type.
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
// Factory that deduces Size and F; the Size argument is used only for its
// type, never its value.
template <class Size, class F>
constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
{
return {{}, f};
}
template <class R, class F> template <class R, class F>
struct storage_access : F struct storage_access : F
{ {
...@@ -278,6 +297,14 @@ struct reducer_base ...@@ -278,6 +297,14 @@ struct reducer_base
}); });
} }
// Builds a lazy inner storage over the sliced inputs: f is applied per
// element (j, d) only when the resulting storage is indexed, rather than
// being evaluated eagerly for every element.
template <class F>
__device__ auto lazy_inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
});
}
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
...@@ -396,25 +423,6 @@ struct block_large ...@@ -396,25 +423,6 @@ struct block_large
index idx; index idx;
Slicer slice; Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class... Ts> template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{ {
...@@ -439,7 +447,7 @@ struct block_large ...@@ -439,7 +447,7 @@ struct block_large
template <class R, class F, class N, class... Ts> template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); }); return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
} }
}; };
...@@ -469,25 +477,6 @@ struct lane ...@@ -469,25 +477,6 @@ struct lane
index idx; index idx;
Slicer slice; Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class U, class... Us> template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{ {
...@@ -518,7 +507,7 @@ struct lane ...@@ -518,7 +507,7 @@ struct lane
template <class R, class F, class N, class... Ts> template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); }); return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
} }
}; };
template <class Slicer> template <class Slicer>
...@@ -577,5 +566,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu ...@@ -577,5 +566,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
}); });
} }
// Kernel body for a fused reduction. Dispatches through
// Algo::template run<Reduced>, which supplies an output index and a reducer
// object r for each output slice. f(r) computes the fused result:
// - if f yields an inner storage (per-element values), copy it element-wise
//   into output via r.inner;
// - otherwise the result is a single value, written once per output slice
//   via r.outer (presumably by one thread per slice — confirm against
//   the Algo implementations).
template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
}
else
{
r.outer([&] { output[out_idx] = implicit_conversion(result); });
}
});
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP #endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
...@@ -26,13 +26,13 @@ ...@@ -26,13 +26,13 @@
#include <migraphx/check_context.hpp> #include <migraphx/check_context.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp> #include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp> #include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp> #include <migraphx/layout_nhwc.hpp>
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/normalize_ops.hpp> #include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp> #include <migraphx/optimize_module.hpp>
#include <migraphx/preallocate_param.hpp> #include <migraphx/preallocate_param.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/promote_literals.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_gelu.hpp> #include <migraphx/rewrite_gelu.hpp>
...@@ -48,9 +48,9 @@ ...@@ -48,9 +48,9 @@
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp> #include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
...@@ -73,6 +73,7 @@ namespace gpu { ...@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass struct id_pass
{ {
...@@ -101,6 +102,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -101,6 +102,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// clang-format off // clang-format off
return return
{ {
enable_pass(options.split_single_dyn_dim, split_single_dyn_dim{}),
enable_pass(options.split_single_dyn_dim, dead_code_elimination{}),
normalize_ops{}, normalize_ops{},
dead_code_elimination{}, dead_code_elimination{},
simplify_qdq{}, simplify_qdq{},
...@@ -128,6 +131,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -128,6 +131,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{}, optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}), enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
fuse_mlir{&ctx}, fuse_mlir{&ctx},
dead_code_elimination{}, dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
...@@ -147,6 +152,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -147,6 +152,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
compile_ops{&ctx}, compile_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
promote_literals{},
dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})}, schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"}, memory_coloring{"hip::allocate"},
......
...@@ -31,10 +31,9 @@ set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref) ...@@ -31,10 +31,9 @@ set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h) find_path(BLAZE_INCLUDE blaze/Blaze.h)
find_package(Threads)
rocm_clang_tidy_check(migraphx_ref) rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref migraphx Threads::Threads) target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
......
...@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs) ...@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs)
}; };
} }
// Reference (CPU) implementation of convolution / quant_convolution.
// Op is the operator type (op::convolution or op::quant_convolution); it
// supplies padding, stride, dilation, group and padding_mode.
template <class Op>
struct ref_convolution : auto_register_op<ref_convolution<Op>>
{
ref_convolution() = default;
ref_convolution(Op pop) : op(std::move(pop)) {}
Op op;
// Reflection forwards to the wrapped operator so serialization matches it.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
// args[0] = input tensor (N, C, spatial...), args[1] = weights
// (O, C/group, kernel spatial...). Computes the convolution element by
// element on the host.
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
std::vector<std::size_t> padding;
if(op.padding_mode != op::padding_mode_t::default_)
{
// Auto-padding (same_upper/same_lower): derive the padding from the
// input and weight dims, then recompute the padded output shape.
auto input_lens = args[0].get_shape().lens();
auto weights_lens = args[1].get_shape().lens();
padding =
op.padding_mode == op::same_upper
? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
: calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
output_shape = compute_padded_shape(
args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
}
else
{
padding = op.padding;
// Dynamic output shape: resolve it now from the concrete input shapes.
if(output_shape.dynamic())
{
output_shape =
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
}
argument result{output_shape};
visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_lens = input.get_shape().lens();
auto wei_lens = weights.get_shape().lens();
auto wei_n = wei_lens[0];
auto wei_c = wei_lens[1];
// Kernel window dims: weight lens minus the leading output-channel dim.
std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
// One parallel task per output element.
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto w = idx_o[1]; // output channel
auto n_dim = idx_o.size();
// Top-left corner of the receptive field per spatial dim; may be
// negative because of padding, hence signed arithmetic.
std::vector<std::ptrdiff_t> win_start;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
std::ptrdiff_t(padding[d_2]));
}
// Grouped convolution: this output channel reads input channels
// [group_id * wei_c, (group_id + 1) * wei_c).
const auto group_id = w / (wei_n / op.group);
shape win_shape{output_shape.type(), win_size};
// Accumulate in double to limit rounding error across the window.
double acc = 0.0;
shape_for_each(win_shape, [&](auto idx_win) {
auto k = idx_win[0];
const auto in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
idx[1] = in_ch;
// Translate window coords into absolute input coords.
std::transform(idx_win.begin() + 1,
idx_win.end(),
win_start.begin(),
idx.begin() + 2,
[](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
idx_wei[0] = w;
std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
// Skip taps that fall into the (implicit zero) padding region:
// any spatial coord < 0 or any coord >= the input extent.
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx.begin(),
idx.end(),
in_lens.begin(),
in_lens.end(),
std::less<std::ptrdiff_t>{}))
{
acc +=
input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
}
});
output[i] = acc;
});
});
return result;
}
};
struct ref_im2col struct ref_im2col
{ {
op::im2col op; op::im2col op;
...@@ -564,11 +461,8 @@ struct ref_apply ...@@ -564,11 +461,8 @@ struct ref_apply
void init() void init()
{ {
apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>(); apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>(); apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>(); apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>(); apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>(); apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
......
...@@ -218,3 +218,10 @@ test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/r ...@@ -218,3 +218,10 @@ test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/r
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp) test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp)
endif() endif()
# Header checks for the optional CPU/FPGA targets (test_headers presumably
# verifies each public header compiles standalone — confirm against its
# definition); guarded so they only run when the backend is enabled.
if(MIGRAPHX_ENABLE_CPU)
test_headers(migraphx/cpu ${CMAKE_SOURCE_DIR}/src/targets/cpu/include/migraphx/cpu/*.hpp)
endif()
if(MIGRAPHX_ENABLE_FPGA)
test_headers(migraphx/fpga ${CMAKE_SOURCE_DIR}/src/targets/fpga/include/migraphx/fpga/*.hpp)
endif()
...@@ -36,7 +36,7 @@ bool create_shapes(bool dynamic_allowed) ...@@ -36,7 +36,7 @@ bool create_shapes(bool dynamic_allowed)
try try
{ {
shape a{shape::int64_type, {3}}; shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6, 0}, {4, 4, 0}}}; shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add"); auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2); migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true; return true;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
// Applies the pass under test (fuse_reduce) followed by
// dead_code_elimination to remove the instructions it replaced.
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}});
}
// Checks that every instruction in m only consumes inputs owned by m itself —
// a fused submodule must not reference instructions from the parent module.
bool all_instructions_are_local(const migraphx::module& m)
{
    for(const auto& ins : m)
    {
        for(const auto& input : ins.inputs())
        {
            if(not m.has_instruction(input))
                return false;
        }
    }
    return true;
}
// Builds the expected post-fusion form: creates submodule `name`, adds one
// parameter per input named x0, x1, ... (with a shape rebuilt from type and
// lens, i.e. dropping the input's strides), lets f populate the submodule
// body, and returns a "fused_reduce" instruction added to the main module
// that runs the submodule over `inputs`.
template <class F>
migraphx::instruction_ref add_reduce(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
const std::vector<int64_t>& axes,
F f)
{
auto* rm = p.create_module(name);
auto* mm = p.get_main_module();
rm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return rm->add_parameter(
"x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
});
auto r = f(rm, params, axes);
rm->add_return({r});
// Sanity-check the builder did not reach outside the submodule.
EXPECT(all_instructions_are_local(*rm));
return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm});
}
// Returns an add_reduce builder that emits a single `name` reduction over
// the given axes inside the reduce submodule.
inline auto single_reduce(const std::string& name)
{
return [=](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs);
};
}
// Two independent reduce_sum instructions: each one is moved into its own
// fused_reduce submodule, with no fusion across them.
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y);
mm->add_return({rsum1, rsum2});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum"));
mm->add_return({rsum1, rsum2});
}
EXPECT(p1 == p2);
}
// A pointwise add feeding a reduce_sum fuses into one fused_reduce submodule
// containing both the add and the reduction.
TEST_CASE(pointwise_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add);
mm->add_return({rsum});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = add_reduce(
p2,
"main:pointwise0:main:reduce_sum0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add =
add_pointwise(p2, rm, "main:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
// reduce_sum -> multibroadcast -> pointwise add fuses into one submodule;
// the broadcast back to the input shape stays inside the fused module.
TEST_CASE(reduce_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_reduce(
p2,
"main:reduce_sum0:main:pointwise0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
return add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add"));
});
mm->add_return({add});
}
EXPECT(p1 == p2);
}
// A chain reduce -> broadcast -> pointwise -> reduce -> pointwise over the
// same axes collapses into a single fused_reduce submodule.
TEST_CASE(reduce_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, x}, single_pointwise("sub"));
auto rsum2 =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), rsumdiff);
auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
mm->add_return({sqrt});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto sqrt = add_reduce(
p2,
"main:reduce_sum1:main:reduce_sum0:main:pointwise0:main:pointwise1",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[0]}, single_pointwise("sub"));
auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
rsumdiff);
return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
});
mm->add_return({sqrt});
}
EXPECT(p1 == p2);
}
// Reductions over different axes must NOT merge: each reduce_sum ends up in
// its own fused_reduce submodule.
TEST_CASE(reduce_reduce_mismatch_axis)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsum1);
mm->add_return({rsum2});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {rsum1}, {2}, single_reduce("reduce_sum"));
mm->add_return({rsum2});
}
EXPECT(p1 == p2);
}
// A reduction whose result is used both directly and re-broadcast through
// further pointwise ops and a second reduction fuses into one submodule.
TEST_CASE(pointwise_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto sqrt = add_pointwise(p1, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(p1, "main:pointwise1", {sqrtb, x}, single_pointwise("add"));
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add1);
auto add2 = add_pointwise(p1, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
mm->add_return({add2});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add2 = add_reduce(
p2,
"main:pointwise0:main:pointwise1:main:reduce_sum1:main:pointwise2:main:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto sqrt =
add_pointwise(p2, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(
p2, rm, "main:pointwise1", {sqrtb, inputs[0]}, single_pointwise("add"));
auto rsum2 =
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add1);
return add_pointwise(
p2, rm, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
});
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
// Two pre-existing fused_reduce modules connected through a broadcast are
// merged by the pass into a single fused_reduce module.
TEST_CASE(reduce_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p1, "test:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_reduce(
p1,
"test:reduce_sum1",
{rsumb, x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add2 =
add_pointwise(p1, rm, "test:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add2);
});
mm->add_return({add});
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = add_reduce(
p2,
"test:reduce_sum1:test:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_pointwise(
p2, rm, "test:pointwise0", {rsumb, inputs[0]}, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
TEST_CASE(tuple_to_from_gpu) TEST_CASE(tuple_from_gpu)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
...@@ -47,4 +47,23 @@ TEST_CASE(tuple_to_from_gpu) ...@@ -47,4 +47,23 @@ TEST_CASE(tuple_to_from_gpu)
EXPECT(result2 == p2_data); EXPECT(result2 == p2_data);
} }
// Round-trips a tuple argument (float + int32 tensors) to the GPU and back,
// then checks each sub-object matches its source data.
TEST_CASE(tuple_to_gpu)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
std::vector<float> p1_data = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6};
std::vector<int> p2_data = {1, 2, 3, 4, 5, 6, 7, 8};
auto p1 = migraphx::argument{s1, p1_data.data()};
auto p2 = migraphx::argument{s2, p2_data.data()};
auto p_gpu = migraphx::gpu::to_gpu(migraphx::argument({p1, p2}));
auto p_host = migraphx::gpu::from_gpu(p_gpu);
std::vector<migraphx::argument> results = p_host.get_sub_objects();
std::vector<float> result1;
results[0].visit([&](auto output) { result1.assign(output.begin(), output.end()); });
std::vector<int> result2;
results[1].visit([&](auto output) { result2.assign(output.begin(), output.end()); });
EXPECT(result1 == p1_data);
EXPECT(result2 == p2_data);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -30,12 +30,12 @@ ...@@ -30,12 +30,12 @@
template <class F> template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p, migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name, const std::string& name,
std::vector<migraphx::instruction_ref> inputs, std::vector<migraphx::instruction_ref> inputs,
F f) F f)
{ {
auto* pm = p.create_module(name); auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass(); pm->set_bypass();
std::vector<migraphx::instruction_ref> params; std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
...@@ -47,6 +47,15 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p, ...@@ -47,6 +47,15 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm}); return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
} }
// Convenience overload: builds the pointwise submodule and attaches the
// resulting "pointwise" instruction to the program's main module.
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}
inline auto single_pointwise(const std::string& name) inline auto single_pointwise(const std::string& name)
{ {
return [=](auto* pm, const auto& inputs) { return [=](auto* pm, const auto& inputs) {
......
This diff is collapsed.
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
// Run the promote_literals pass on the program, then clean up any
// instructions the pass left dead.
void run_promote(migraphx::program& p)
{
    migraphx::promote_literals promote;
    migraphx::dead_code_elimination dce;
    migraphx::run_passes(p, {promote, dce});
}
// Run promote_literals followed by eliminate_common_subexpression,
// eliminating dead code after each transform.
void run_promote_and_ecs(migraphx::program& p)
{
    migraphx::promote_literals promote;
    migraphx::dead_code_elimination dce;
    migraphx::eliminate_common_subexpression ecs;
    migraphx::run_passes(p, {promote, dce, ecs, dce});
}
// promote_literals alone: each submodule's literal is hoisted into the main
// module, producing one main-module literal per submodule (duplicates are
// NOT merged without eliminate_common_subexpression).
TEST_CASE(promote_only)
{
    // p0: select_module over four batch-size submodules, each holding its
    // own copy of the scalar literal {6}.
    migraphx::program p0;
    {
        auto* mm0 = p0.get_main_module();
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
            auto* submod = p0.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
            auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
            auto broadcast_lit =
                submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
            submod->add_return({add_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, "dim_1");
        auto* dim2 = create_submodule(2, "dim_2");
        auto* dim3 = create_submodule(3, "dim_3");
        auto* dim4 = create_submodule(4, "dim_4");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm0->add_parameter("data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm0->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2, dim3, dim4});
        auto ret =
            mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm0->add_return({ret});
    }
    run_promote(p0);

    // p1: the expected result — four {6} literals now live in the main
    // module (the submodules reference them directly), and the "data"
    // parameter is inserted after the hoisted literals.
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
        auto literal_ins3 = mm1->add_literal(migraphx::literal{lit_s, {6}});
        auto literal_ins2 = mm1->add_literal(migraphx::literal{lit_s, {6}});
        auto literal_ins1 = mm1->add_literal(migraphx::literal{lit_s, {6}});
        auto literal_ins0 = mm1->add_literal(migraphx::literal{lit_s, {6}});
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size,
                                    migraphx::instruction_ref lit,
                                    const std::string& module_name) {
            auto* submod = p1.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            auto broadcast_lit =
                submod->add_instruction(migraphx::make_op("multibroadcast"), lit, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
            submod->add_return({add_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, literal_ins0, "dim_1");
        auto* dim2 = create_submodule(2, literal_ins1, "dim_2");
        auto* dim3 = create_submodule(3, literal_ins2, "dim_3");
        auto* dim4 = create_submodule(4, literal_ins3, "dim_4");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm1->insert_parameter(std::next(literal_ins3), "data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm1->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2, dim3, dim4});
        auto ret =
            mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm1->add_return({ret});
    }
    EXPECT(p0 == p1);
}
// promote_literals + eliminate_common_subexpression: the four identical {6}
// literals hoisted from the submodules are merged into a single main-module
// literal that all submodules share.
TEST_CASE(promote_and_ecs0)
{
    // p0: same starting program as promote_only — four submodules, each
    // with its own copy of the scalar literal {6}.
    migraphx::program p0;
    {
        auto* mm0 = p0.get_main_module();
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
            auto* submod = p0.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
            auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
            auto broadcast_lit =
                submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
            submod->add_return({add_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, "dim_1");
        auto* dim2 = create_submodule(2, "dim_2");
        auto* dim3 = create_submodule(3, "dim_3");
        auto* dim4 = create_submodule(4, "dim_4");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm0->add_parameter("data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm0->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2, dim3, dim4});
        auto ret =
            mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm0->add_return({ret});
    }
    run_promote_and_ecs(p0);

    // p1: expected result — exactly one {6} literal in the main module,
    // referenced by every submodule; "data" is inserted after it.
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
        auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
            auto* submod = p1.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            auto broadcast_lit =
                submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
            submod->add_return({add_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, "dim_1");
        auto* dim2 = create_submodule(2, "dim_2");
        auto* dim3 = create_submodule(3, "dim_3");
        auto* dim4 = create_submodule(4, "dim_4");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm1->insert_parameter(std::next(literal_ins), "data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm1->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2, dim3, dim4});
        auto ret =
            mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm1->add_return({ret});
    }
    EXPECT(p0 == p1);
}
// promote_literals + ECS with two DISTINCT literals ({6} and {2}) per
// submodule: duplicates of each value are merged, but the two different
// values remain as two separate main-module literals.
TEST_CASE(promote_and_ecs1)
{
    // p0: two submodules, each computing (data + 6) * 2 with its own
    // copies of the literals.
    migraphx::program p0;
    {
        auto* mm0 = p0.get_main_module();
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
            auto* submod = p0.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
            auto literal_ins0 = submod->add_literal(migraphx::literal{lit_s, {6}});
            auto literal_ins1 = submod->add_literal(migraphx::literal{lit_s, {2}});
            auto broadcast_lit0 = submod->add_instruction(
                migraphx::make_op("multibroadcast"), literal_ins0, sm_input);
            auto broadcast_lit1 = submod->add_instruction(
                migraphx::make_op("multibroadcast"), literal_ins1, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit0);
            auto mul_ins =
                submod->add_instruction(migraphx::make_op("mul"), add_ins, broadcast_lit1);
            submod->add_return({mul_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, "dim_1");
        auto* dim2 = create_submodule(2, "dim_2");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm0->add_parameter("data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm0->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2});
        auto ret =
            mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm0->add_return({ret});
    }
    run_promote_and_ecs(p0);

    // p1: expected result — one {2} and one {6} literal in the main module
    // (note the order in which they are added), shared by both submodules.
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
        auto literal_ins1 = mm1->add_literal(migraphx::literal{lit_s, {2}});
        auto literal_ins0 = mm1->add_literal(migraphx::literal{lit_s, {6}});
        // create batch submodules
        auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
            auto* submod = p1.create_module(module_name);
            migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
            auto sm_input = submod->add_parameter("data", sm_shape);
            auto broadcast_lit0 = submod->add_instruction(
                migraphx::make_op("multibroadcast"), literal_ins0, sm_input);
            auto broadcast_lit1 = submod->add_instruction(
                migraphx::make_op("multibroadcast"), literal_ins1, sm_input);
            auto add_ins =
                submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit0);
            auto mul_ins =
                submod->add_instruction(migraphx::make_op("mul"), add_ins, broadcast_lit1);
            submod->add_return({mul_ins});
            return submod;
        };
        auto* dim1 = create_submodule(1, "dim_1");
        auto* dim2 = create_submodule(2, "dim_2");

        migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
        auto input0 = mm1->insert_parameter(std::next(literal_ins1), "data", s);
        std::vector<migraphx::shape> sub_shapes = {};
        sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
        migraphx::shape out_attr = migraphx::shape{sub_shapes};
        auto sm_ins = mm1->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            {input0},
            {dim1, dim2});
        auto ret =
            mm1->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
        mm1->add_return({ret});
    }
    EXPECT(p0 == p1);
}
// Test-suite entry point: dispatch to the registered TEST_CASEs.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -1197,7 +1197,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1197,7 +1197,7 @@ TEST_CASE(dot_dyn_2D_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}}; migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
auto ap = mm->add_parameter("a", a_shape); auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape); auto bp = mm->add_parameter("b", b_shape);
...@@ -1250,8 +1250,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1250,8 +1250,7 @@ TEST_CASE(dot_dyn_4D_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, migraphx::shape a_shape{migraphx::shape::float_type, {{1, 1}, {1, 1}, {4, 6, {4}}, {5, 5}}};
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape); auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape); auto bl = mm->add_parameter("b", b_shape);
......
This diff is collapsed.
...@@ -33,12 +33,20 @@ ...@@ -33,12 +33,20 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/verify.hpp>
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
// Apply the rewrite_quantization pass to the given module.
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::rewrite_quantization{}});
}
// Evaluate a program with no inputs and return its single result,
// checking that exactly one output was produced.
migraphx::argument eval(const migraphx::program& p)
{
    auto results = p.eval({});
    EXPECT(results.size() == 1);
    return results.front();
}
TEST_CASE(quantizelinear) TEST_CASE(quantizelinear)
{ {
...@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear) ...@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt; run_pass(*p2.get_main_module());
opt.apply(*p2.get_main_module()); EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
} }
...@@ -71,9 +79,9 @@ TEST_CASE(dequantizelinear) ...@@ -71,9 +79,9 @@ TEST_CASE(dequantizelinear)
std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250}; std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250};
migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}};
std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2};
migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}}; migraphx::shape zs{migraphx::shape::float_type, {1, 3, 3}};
std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<float> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
auto create_program = [&]() { auto create_program = [&]() {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv); auto x = mm->add_literal(xs, xv);
...@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear) ...@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt; run_pass(*p2.get_main_module());
opt.apply(*p2.get_main_module()); EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
} }
......
...@@ -41,22 +41,13 @@ TEST_CASE(test_shape_default) ...@@ -41,22 +41,13 @@ TEST_CASE(test_shape_default)
TEST_CASE(test_dyn_4arg_constructor) TEST_CASE(test_dyn_4arg_constructor)
{ {
migraphx::shape s{migraphx::shape::float_type, migraphx::shape s0{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {{}, {}, {}}};
{ migraphx::shape s1{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {}};
1, std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {{1, 4}, {4, 4}, {4, 4}};
4, EXPECT(s0.dynamic());
4, EXPECT(s0.dyn_dims() == expected_dyn_dims);
}, EXPECT(s1.dynamic());
{ EXPECT(s1.dyn_dims() == expected_dyn_dims);
4,
4,
4,
},
{0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
} }
TEST_CASE(test_shape_assign) TEST_CASE(test_shape_assign)
...@@ -85,17 +76,26 @@ TEST_CASE(test_shape_standard) ...@@ -85,17 +76,26 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
// A size-1 dimension can carry any stride without changing the memory
// layout: {5, 1, 8} with strides {8, 4, 1} (canonical would be {8, 8, 1})
// must still be reported as standard and packed.
TEST_CASE(test_shape_standard_singleton_dim)
{
    migraphx::shape s{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
    EXPECT(s.standard());
    EXPECT(s.packed());
    EXPECT(not s.transposed());
    EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt) TEST_CASE(test_shape_min_max_opt)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens()); EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens()); EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens()); EXPECT(s.opt_lens().empty());
} }
TEST_CASE(test_shape_dynamic_fixed) TEST_CASE(test_shape_dynamic_fixed)
{ {
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {2, 2}, {3, 3}}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -106,7 +106,8 @@ TEST_CASE(test_shape_dynamic_fixed) ...@@ -106,7 +106,8 @@ TEST_CASE(test_shape_dynamic_fixed)
EXPECT(not s.dyn_dims().at(0).has_optimal()); EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3}); EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3}); EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0}); std::vector<std::set<std::size_t>> e_opt_lens = {{}, {}, {}};
EXPECT(s.opt_lens() == e_opt_lens);
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float)); EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
} }
...@@ -114,8 +115,8 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -114,8 +115,8 @@ TEST_CASE(test_shape_dynamic_not_fixed)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {}; std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2}); dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8, 0}); dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
...@@ -127,18 +128,16 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -127,18 +128,16 @@ TEST_CASE(test_shape_dynamic_not_fixed)
EXPECT(s.dyn_dims().at(0).has_optimal()); EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2}); EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8}); EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0}); EXPECT(s.opt_lens() == std::vector<std::set<std::size_t>>{{2}, {}});
EXPECT(s.bytes() == 5 * 8 * sizeof(float)); EXPECT(s.bytes() == 5 * 8 * sizeof(float));
} }
TEST_CASE(test_shape_dynamic_compares) TEST_CASE(test_shape_dynamic_compares)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2}; auto a = shape::dynamic_dimension{2, 5, {2}};
auto b = a; auto c = shape::dynamic_dimension{2, 5, {2}};
auto c = shape::dynamic_dimension{2, 5, 2}; auto d = shape::dynamic_dimension{3, 8};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
EXPECT(a == c); EXPECT(a == c);
EXPECT(a != d); EXPECT(a != d);
...@@ -163,13 +162,13 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -163,13 +162,13 @@ TEST_CASE(test_shape_dynamic_compares)
TEST_CASE(dynamic_dimension_size_t_compares) TEST_CASE(dynamic_dimension_size_t_compares)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2}; auto a = shape::dynamic_dimension{2, 2, {2}};
EXPECT(a == 2); EXPECT(a == 2);
EXPECT(a != 3); EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a); EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a); EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0}; auto b = shape::dynamic_dimension{2, 4};
EXPECT(b != 2); EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b); EXPECT(static_cast<std::size_t>(2) != b);
} }
...@@ -177,25 +176,25 @@ TEST_CASE(dynamic_dimension_size_t_compares) ...@@ -177,25 +176,25 @@ TEST_CASE(dynamic_dimension_size_t_compares)
TEST_CASE(dynamic_dimension_add_sub_fixed) TEST_CASE(dynamic_dimension_add_sub_fixed)
{ {
using migraphx::shape; using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2}; auto a = shape::dynamic_dimension{2, 5, {2}};
a += 3; a += 3;
EXPECT(a == shape::dynamic_dimension{5, 8, 5}); EXPECT(a == shape::dynamic_dimension{5, 8, {5}});
a -= 3; a -= 3;
EXPECT(a == shape::dynamic_dimension{2, 5, 2}); EXPECT(a == shape::dynamic_dimension{2, 5, {2}});
auto b = shape::dynamic_dimension{3, 6, 3}; auto b = shape::dynamic_dimension{3, 6, {3}};
EXPECT((a + 1) == b); EXPECT((a + 1) == b);
EXPECT((1 + a) == b); EXPECT((1 + a) == b);
EXPECT((b - 1) == a); EXPECT((b - 1) == a);
auto c = shape::dynamic_dimension{4, 7, 4}; auto c = shape::dynamic_dimension{4, 7, {4}};
EXPECT((a + 2) == c); EXPECT((a + 2) == c);
EXPECT((2 + a) == c); EXPECT((2 + a) == c);
EXPECT((c - 2) == a); EXPECT((c - 2) == a);
auto d = shape::dynamic_dimension{4, 8, 0}; auto d = shape::dynamic_dimension{4, 8};
auto e = shape::dynamic_dimension{2, 6, 0}; auto e = shape::dynamic_dimension{2, 6};
EXPECT((d - 2) == e); EXPECT((d - 2) == e);
EXPECT((e + 2) == d); EXPECT((e + 2) == d);
EXPECT((2 + e) == d); EXPECT((2 + e) == d);
...@@ -205,8 +204,8 @@ TEST_CASE(test_shape_dynamic_errors) ...@@ -205,8 +204,8 @@ TEST_CASE(test_shape_dynamic_errors)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {}; std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2}); dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8, 0}); dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{shape::float_type, dims}; migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); })); EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); })); EXPECT(test::throws([&] { s.index({0, 1}); }));
...@@ -220,13 +219,13 @@ TEST_CASE(test_shape_dynamic_serialize) ...@@ -220,13 +219,13 @@ TEST_CASE(test_shape_dynamic_serialize)
{ {
using migraphx::shape; using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {}; std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2}); dims1.push_back(shape::dynamic_dimension{2, 5, {2}});
dims1.push_back(shape::dynamic_dimension{2, 8, 0}); dims1.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s1{shape::float_type, dims1}; migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1); auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {}; std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2}); dims2.push_back(shape::dynamic_dimension{2, 5, {2}});
migraphx::shape s2{shape::uint64_type, dims2}; migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2); auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2); EXPECT(v1 != v2);
...@@ -285,14 +284,13 @@ TEST_CASE(test_shape_ndim_static) ...@@ -285,14 +284,13 @@ TEST_CASE(test_shape_ndim_static)
TEST_CASE(test_shape_ndim_dyn) TEST_CASE(test_shape_ndim_dyn)
{ {
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}}; migraphx::shape s0{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
EXPECT(s0.ndim() == 2); EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}}; migraphx::shape s1{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(s1.ndim() == 4); EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {1, 1}, {3, 3}}};
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s2.ndim() == 5); EXPECT(s2.ndim() == 5);
} }
...@@ -327,17 +325,60 @@ TEST_CASE(test_shape_static_to_dynamic) ...@@ -327,17 +325,60 @@ TEST_CASE(test_shape_static_to_dynamic)
{ {
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}}; migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic(); migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}}; migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}};
EXPECT(s1 == s2); EXPECT(s1 == s2);
} }
TEST_CASE(test_shape_dyn_to_dynamic) TEST_CASE(test_shape_dyn_to_dynamic)
{ {
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}}; migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
migraphx::shape s1 = s0.to_dynamic(); migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1); EXPECT(s0 == s1);
} }
// to_dynamic() on a tuple shape converts each static subshape to a dynamic
// one and leaves already-dynamic subshapes unchanged.
TEST_CASE(test_shape_subshapes_to_dynamic)
{
    migraphx::shape dyn_sub{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
    migraphx::shape static_sub{migraphx::shape::float_type, {3, 4, 5}};
    migraphx::shape input{std::vector<migraphx::shape>{dyn_sub, static_sub}};

    migraphx::shape converted = input.to_dynamic();

    migraphx::shape dynamized_sub{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}};
    migraphx::shape expected{std::vector<migraphx::shape>{dyn_sub, dynamized_sub}};
    EXPECT(converted == expected);
}
// to_static(4) on a dynamic shape: non-fixed dimensions ({2, 10}) become 4,
// while fixed dimensions ({1, 1} and {2, 2}) keep their fixed size.
TEST_CASE(test_shape_dyn_to_static)
{
    migraphx::shape dyn_shape{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}};
    migraphx::shape expected{migraphx::shape::float_type, {1, 2, 4, 4}};
    EXPECT(dyn_shape.to_static(4) == expected);
}
// to_static() on an already-static shape is a no-op, regardless of the
// value passed in.
TEST_CASE(test_shape_static_to_static)
{
    migraphx::shape original{migraphx::shape::float_type, {1, 2, 4, 4}};
    EXPECT(original.to_static(8) == original);
}
// to_static() on a tuple shape fixes each dynamic subshape (non-fixed
// dimensions take the given value) and leaves static subshapes unchanged.
TEST_CASE(test_shape_subshapes_to_static)
{
    migraphx::shape dyn_sub{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
    migraphx::shape static_sub{migraphx::shape::float_type, {3, 4, 5}};
    migraphx::shape input{std::vector<migraphx::shape>{dyn_sub, static_sub}};

    migraphx::shape result = input.to_static(3);

    migraphx::shape fixed_sub{migraphx::shape::float_type, {3, 4}};
    migraphx::shape expected{std::vector<migraphx::shape>{fixed_sub, static_sub}};
    EXPECT(result == expected);
}
TEST_CASE(test_shape_overlap) TEST_CASE(test_shape_overlap)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
......
...@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add) ...@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// The pass should distribute convolution over addition when one addend is a
// literal: conv(c + x, w) is rewritten to conv(c, w) + conv(x, w), which
// lets the literal-only conv be constant-folded later.
TEST_CASE(simplify_conv_add)
{
    migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
    // m1: conv applied to (c + x)
    migraphx::module m1;
    {
        auto x    = m1.add_parameter("x", s);
        auto c    = m1.add_literal(migraphx::generate_literal(s, 1));
        auto w    = m1.add_literal(migraphx::generate_literal(ws, 2));
        auto sum  = m1.add_instruction(migraphx::make_op("add"), c, x);
        auto conv = m1.add_instruction(migraphx::make_op("convolution"), sum, w);
        m1.add_instruction(pass_op{}, conv);
    }
    run_pass(m1);
    // m2: expected result — two convolutions with the same weights, summed
    migraphx::module m2;
    {
        auto x     = m2.add_parameter("x", s);
        auto c     = m2.add_literal(migraphx::generate_literal(s, 1));
        auto w     = m2.add_literal(migraphx::generate_literal(ws, 2));
        auto conv1 = m2.add_instruction(migraphx::make_op("convolution"), c, w);
        auto conv2 = m2.add_instruction(migraphx::make_op("convolution"), x, w);
        auto sum   = m2.add_instruction(migraphx::make_op("add"), conv1, conv2);
        m2.add_instruction(pass_op{}, sum);
    }
    EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment