Unverified Commit f00a0b9b authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into layernorm_eps

parents 86713e78 97a1ed2d
......@@ -26,8 +26,6 @@ on:
required: true
default: '-s'
concurrency: benchmark
jobs:
release:
uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
......
......@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
......@@ -84,6 +84,12 @@ argument
Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: data_ptr()
Returns the address to the underlying argument data.
:rtype: int
.. py:method:: get_shape()
Returns the shape of the argument.
......@@ -113,7 +119,16 @@ argument
:param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument.
:rtype argument
:rtype: argument
.. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy.
:param shape shape: Shape of the data stored in address.
:param long address: Memory address of data from another source
:rtype: argument
target
------
......
......@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__",
[](migraphx::argument& x, py::buffer b) {
.def(py::init([](py::buffer b) {
py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr);
})
return migraphx::argument(to_shape(info), info.ptr);
}))
.def("get_shape", &migraphx::argument::get_shape)
.def("data_ptr",
[](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
.def("tolist",
[](migraphx::argument& x) {
py::list l{x.get_shape().elements()};
......
......@@ -985,20 +985,35 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end())
{
return;
}
int rsp_axis = std::distance(rsp_strides.begin(), ait);
else if(ait == rsp_strides.end() - 1)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction and add contiguous if needed
......
......@@ -271,6 +271,44 @@ struct find_nested_slice
}
};
struct find_concat_multibroadcasts
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
}))
{
return;
}
// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
return i->inputs().front();
});
// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
}
};
struct find_concat_transpose
{
auto matcher() const
......@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{},
find_transpose{},
find_concat_transpose{},
find_concat_multibroadcasts{},
find_nested_convert{},
find_nested_slice{},
find_nested_concat{},
......
......@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
find_library(MLIRAPI_LIBRARY MLIRMIOpen
PATH_SUFFIXES
# Workaournd broken mlir install
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18
if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpen not found")
else()
message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif()
find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
# Workaround MLIR broken installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
# Find package rocMLIR
find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2})
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......
......@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
};
MIGRAPHX_REGISTER_OP(hip_add_relu)
struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{
};
MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
......
......@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#ifndef MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
#define MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
namespace migraphx {
template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto strides = get_shape_c<Output>{}.strides;
constexpr auto offset = return_c([] {
constexpr auto output_shape = get_shape_c<Output>{};
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
}
template <index_int Axis, class Input>
constexpr auto concat_ends(Input)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>;
}
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
......@@ -63,6 +63,7 @@ __device__ void generic_binary_layernorm(
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...);
......
......@@ -151,7 +151,6 @@ struct miopen_apply
add_extend_op("argmax");
add_extend_op("argmin");
add_extend_op("clip");
add_extend_op("concat");
add_extend_op("convert");
add_extend_op("elu");
add_extend_op("gather");
......
......@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/dead_code_elimination.hpp"
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
......
......@@ -144,7 +144,7 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
}
......@@ -167,7 +167,7 @@ TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......
......@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/instruction_ref.hpp"
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
......
......@@ -2077,6 +2077,55 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_len_1)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {1, 128, 3}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {2}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 128};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 128, 3}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 384};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {128}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {128}}, {"ends", {256}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {256}}, {"ends", {384}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply)
{
auto create_p = [] {
......
......@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
return result;
}
migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
const std::vector<size_t>& mbcast_lens,
const int axis)
{
migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, in_lens};
auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s);
auto z = m.add_parameter("z", s);
auto xm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), x);
auto ym =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), y);
auto zm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), z);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), xm, ym, zm);
m.add_return({concat});
return m;
}
TEST_CASE(double_contig)
{
migraphx::program p;
......@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}
TEST_CASE(concat_multibroadcasts1)
{
// Broadcasted batch dim, new axis < old axis
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
TEST_CASE(concat_multibroadcasts2)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
}
TEST_CASE(concat_multibroadcasts3)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
}
TEST_CASE(concat_multibroadcasts4)
{
// Broadcasted batch dim, axis is broadcasted dim
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto m1 = m;
run_pass(m);
EXPECT(m1 == m);
}
TEST_CASE(concat_transpose1)
{
migraphx::module m;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_concat_axis_2 : verify_program<test_concat_axis_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s2{migraphx::shape::int32_type, {3, 2, 1}};
auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l0, l1, l2);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l2);
auto l3 = mm->add_literal(migraphx::generate_literal(m2_shape));
l3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l3);
auto l4 = mm->add_literal(migraphx::generate_literal(m2_shape));
l4 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l4);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l2, l3, l4);
auto l5 = mm->add_parameter("3", m3_shape);
float alpha = 1.0f;
float beta = 1.0f;
migraphx::add_apply_alpha_beta(
*mm, {l1, concat, l5}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment