"vscode:/vscode.git/clone" did not exist on "627d8ef35a6da8ad268b5197e3045ccdfb4ac684"
Commit 74bd6d61 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-concat' into jit-concat-pointwise

parents b30c3408 8109aac8
...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES) ...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.3) rocm_setup_version(VERSION 2.4)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -84,6 +84,12 @@ argument ...@@ -84,6 +84,12 @@ argument
Construct an argument from a python buffer. This can include numpy arrays. Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: data_ptr()
Returns the address to the underlying argument data.
:rtype: int
.. py:method:: get_shape() .. py:method:: get_shape()
Returns the shape of the argument. Returns the shape of the argument.
...@@ -113,7 +119,16 @@ argument ...@@ -113,7 +119,16 @@ argument
:param shape s: Shape of argument to fill. :param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument. :param int value: Value to fill in the argument.
:rtype argument :rtype: argument
.. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy.
:param shape shape: Shape of the data stored in address.
:param long address: Memory address of data from another source
:rtype: argument
target target
------ ------
......
...@@ -24,17 +24,8 @@ ...@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP #define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath> #include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,17 +24,8 @@ ...@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP #define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath> #include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__", .def(py::init([](py::buffer b) {
[](migraphx::argument& x, py::buffer b) { py::buffer_info info = b.request();
py::buffer_info info = b.request(); return migraphx::argument(to_shape(info), info.ptr);
new(&x) migraphx::argument(to_shape(info), info.ptr); }))
})
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
.def("data_ptr",
[](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
.def("tolist", .def("tolist",
[](migraphx::argument& x) { [](migraphx::argument& x) {
py::list l{x.get_shape().elements()}; py::list l{x.get_shape().elements()};
......
...@@ -1001,20 +1001,35 @@ struct find_split_reshape ...@@ -1001,20 +1001,35 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens(); auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides(); auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]); rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end()) if(ait == rsp_strides.end())
{ {
return; return;
} }
int rsp_axis = std::distance(rsp_strides.begin(), ait); else if(ait == rsp_strides.end() - 1)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape // calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size()); std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis]; return is->get_shape().lens()[rsp_axis];
}); });
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
......
...@@ -271,6 +271,44 @@ struct find_nested_slice ...@@ -271,6 +271,44 @@ struct find_nested_slice
} }
}; };
struct find_concat_multibroadcasts
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
}))
{
return;
}
// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
return i->inputs().front();
});
// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
}
};
struct find_concat_transpose struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
...@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{}, find_reshaper{},
find_transpose{}, find_transpose{},
find_concat_transpose{}, find_concat_transpose{},
find_concat_multibroadcasts{},
find_nested_convert{}, find_nested_convert{},
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
......
...@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") ...@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
find_library(MLIRAPI_LIBRARY MLIRMIOpen # Find package rocMLIR
PATH_SUFFIXES find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
# Workaournd broken mlir install message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18
if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpen not found")
else()
message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif()
find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
# Workaround MLIR broken installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR") target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2}) target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......
...@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[&](const auto& input) -> std::size_t { [&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis]; auto stride = input.strides()[axis];
auto len = input.lens()[axis]; auto len = input.lens()[axis];
if(stride != 0 and stride != 1) if(not contains({0, 1}, stride))
return 1; return 1;
if(len == 1 and input.elements() > sizes.front()) if(len == 1 and input.elements() > sizes.front())
return sizes.front(); return sizes.front();
auto it = std::find_if( auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); // The len is divisible by the size and all the strides are divisible by
// the size
return (len % vsize) == 0 and
std::all_of(
input.strides().begin(), input.strides().end(), [&](auto i) {
return contains({0, 1}, i) or i % vsize == 0;
});
});
if(it != sizes.end()) if(it != sizes.end())
return *it; return *it;
return 1; return 1;
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/oper.hpp> #include <migraphx/gpu/oper.hpp>
...@@ -50,8 +49,6 @@ ...@@ -50,8 +49,6 @@
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -262,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu> ...@@ -262,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
}; };
MIGRAPHX_REGISTER_OP(hip_add_relu) MIGRAPHX_REGISTER_OP(hip_add_relu)
struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid> struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{ {
}; };
MIGRAPHX_REGISTER_OP(hip_add_sigmoid) MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
...@@ -1036,7 +1033,7 @@ struct find_gemm_pointwise ...@@ -1036,7 +1033,7 @@ struct find_gemm_pointwise
// const-fold input if not standard shape since rocblas can't handle it // const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard()) if(not c_ins->get_shape().standard())
{ {
auto c = op::contiguous{}; auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()}); auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data()); c_ins = m.add_literal(l.get_shape(), l.data());
} }
......
...@@ -176,8 +176,13 @@ void gemm_impl(context& ctx, ...@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
{ {
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as // the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
......
...@@ -24,22 +24,10 @@ ...@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp> #include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP #define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp> #include <string>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp> #include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,22 +24,10 @@ ...@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp> #include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp> #include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment