Commit 8109aac8 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-concat

parents b8e448b5 10f37f49
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -264,11 +264,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -264,11 +264,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__", .def(py::init([](py::buffer b) {
[](migraphx::argument& x, py::buffer b) { py::buffer_info info = b.request();
py::buffer_info info = b.request(); return migraphx::argument(to_shape(info), info.ptr);
new(&x) migraphx::argument(to_shape(info), info.ptr); }))
})
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
.def("data_ptr", .def("data_ptr",
[](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); }) [](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
......
...@@ -985,20 +985,35 @@ struct find_split_reshape ...@@ -985,20 +985,35 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens(); auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides(); auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]); rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end()) if(ait == rsp_strides.end())
{ {
return; return;
} }
int rsp_axis = std::distance(rsp_strides.begin(), ait); else if(ait == rsp_strides.end() - 1)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape // calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size()); std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis]; return is->get_shape().lens()[rsp_axis];
}); });
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
......
...@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") ...@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
find_library(MLIRAPI_LIBRARY MLIRMIOpen # Find package rocMLIR
PATH_SUFFIXES find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
# Workaournd broken mlir install message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18
if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpen not found")
else()
message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif()
find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
# Workaround MLIR broken installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR") target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2}) target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......
...@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu> ...@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
}; };
MIGRAPHX_REGISTER_OP(hip_add_relu) MIGRAPHX_REGISTER_OP(hip_add_relu)
struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid> struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{ {
}; };
MIGRAPHX_REGISTER_OP(hip_add_sigmoid) MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/dead_code_elimination.hpp" #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
......
...@@ -144,7 +144,7 @@ TEST_CASE(conv) ...@@ -144,7 +144,7 @@ TEST_CASE(conv)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} { func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32> return %0 : tensor<1x2x2x2xf32>
} }
...@@ -167,7 +167,7 @@ TEST_CASE(conv_add_relu) ...@@ -167,7 +167,7 @@ TEST_CASE(conv_add_relu)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} { func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/instruction_ref.hpp" #include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
......
...@@ -2077,6 +2077,55 @@ TEST_CASE(reorder_reshape_slice_move_axis2) ...@@ -2077,6 +2077,55 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(reorder_reshape_slice_len_1)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {1, 128, 3}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {2}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 128};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 128, 3}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 384};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {128}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {128}}, {"ends", {256}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {256}}, {"ends", {384}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply) TEST_CASE(reorder_reshape_slice_not_apply)
{ {
auto create_p = [] { auto create_p = [] {
......
...@@ -33,18 +33,18 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1> ...@@ -33,18 +33,18 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 384, 768}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {768, 768}}; migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {4, 384, 2304}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l2); l2);
auto l3 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l3 = mm->add_literal(migraphx::generate_literal(m2_shape));
l3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), l3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l3); l3);
auto l4 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l4 = mm->add_literal(migraphx::generate_literal(m2_shape));
l4 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), l4 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
l4); l4);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l2, l3, l4); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l2, l3, l4);
......
...@@ -33,11 +33,11 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2> ...@@ -33,11 +33,11 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 384, 768}}; migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {768, 768}}; migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
l2); l2);
mm->add_instruction(migraphx::make_op("dot"), l1, l2); mm->add_instruction(migraphx::make_op("dot"), l1, l2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment