Commit e27e3e13 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-poc

parents 1dd11890 d9578ba6
...@@ -26,8 +26,6 @@ on: ...@@ -26,8 +26,6 @@ on:
required: true required: true
default: '-s' default: '-s'
concurrency: benchmark
jobs: jobs:
release: release:
uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
......
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -84,6 +84,12 @@ argument ...@@ -84,6 +84,12 @@ argument
Construct an argument from a python buffer. This can include numpy arrays. Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: data_ptr()
Returns the address to the underlying argument data.
:rtype: int
.. py:method:: get_shape() .. py:method:: get_shape()
Returns the shape of the argument. Returns the shape of the argument.
...@@ -113,7 +119,16 @@ argument ...@@ -113,7 +119,16 @@ argument
:param shape s: Shape of argument to fill. :param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument. :param int value: Value to fill in the argument.
:rtype argument :rtype: argument
.. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy.
:param shape shape: Shape of the data stored in address.
:param long address: Memory address of data from another source
:rtype: argument
target target
------ ------
......
...@@ -50,8 +50,8 @@ struct layernorm_matcher ...@@ -50,8 +50,8 @@ struct layernorm_matcher
{ {
return f("div")(arg(0)(x_minus_mean()), return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")( arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f)))))))); f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
} }
auto matcher() const { return layernorm_onnx(); } auto matcher() const { return layernorm_onnx(); }
......
...@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__", .def(py::init([](py::buffer b) {
[](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request(); py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr); return migraphx::argument(to_shape(info), info.ptr);
}) }))
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
.def("data_ptr",
[](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
.def("tolist", .def("tolist",
[](migraphx::argument& x) { [](migraphx::argument& x) {
py::list l{x.get_shape().elements()}; py::list l{x.get_shape().elements()};
......
...@@ -57,12 +57,14 @@ auto conv_const_weights() ...@@ -57,12 +57,14 @@ auto conv_const_weights()
auto reduction() { return match::name_contains("reduce"); } auto reduction() { return match::name_contains("reduce"); }
// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"), return match::name("mul")(
match::name("broadcast").bind("a"))); match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast", "multibroadcast").bind("a")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -72,14 +74,35 @@ struct find_mul_conv ...@@ -72,14 +74,35 @@ struct find_mul_conv
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"]; auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator()); const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();
if(broadcast_op.axis != 1)
std::size_t num_not_one_dims = std::count_if(
a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
if(num_not_one_dims > 1)
return;
// check broadcasted along channels
const auto& a_lens = a_ins->get_shape().lens();
const auto& a_strides = a_ins->get_shape().strides();
auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };
if(a_strides.at(1) != 1)
return;
if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;
if(not std::equal(a_lens.begin() + 2,
a_lens.end(),
a_strides.begin() + 2,
a_strides.end(),
is_broadcasted_axis))
return; return;
auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction( auto new_a = m.insert_instruction(
ins, ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins); auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction( auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
...@@ -985,20 +1008,35 @@ struct find_split_reshape ...@@ -985,20 +1008,35 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens(); auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides(); auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]); rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size); auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end()) if(ait == rsp_strides.end())
{ {
return; return;
} }
int rsp_axis = std::distance(rsp_strides.begin(), ait); else if(ait == rsp_strides.end() - 1)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape // calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size()); std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis]; return is->get_shape().lens()[rsp_axis];
}); });
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
......
...@@ -271,6 +271,44 @@ struct find_nested_slice ...@@ -271,6 +271,44 @@ struct find_nested_slice
} }
}; };
struct find_concat_multibroadcasts
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
}))
{
return;
}
// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
return i->inputs().front();
});
// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
}
};
struct find_concat_transpose struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
...@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{}, find_reshaper{},
find_transpose{}, find_transpose{},
find_concat_transpose{}, find_concat_transpose{},
find_concat_multibroadcasts{},
find_nested_convert{}, find_nested_convert{},
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
......
...@@ -324,26 +324,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") ...@@ -324,26 +324,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
find_library(MLIRAPI_LIBRARY MLIRMIOpen # Find package rocMLIR
PATH_SUFFIXES find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
# Workaournd broken mlir install message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18
if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpen not found")
else()
message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif()
find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
# Workaround MLIR broken installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR") target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2}) target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......
...@@ -138,16 +138,16 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -138,16 +138,16 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t groups = (n + local - 1) / local; std::size_t groups = (n + local - 1) / local;
std::size_t max_blocks = max_global / local; std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return nglobal; return std::min(nglobal, n);
}; };
} }
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{ {
size_t block_size = 128; const std::size_t min_block_size = 64;
while(block_size <= max_block_size and block_size <= n) const std::size_t base_block_size = 32;
block_size *= 2; auto block_size = (((n - 1) / base_block_size + 1)) * base_block_size;
return block_size / 2; return std::min(std::max(min_block_size, block_size), max_block_size);
} }
operation compile_hip_code_object(const std::string& content, hip_compile_options options) operation compile_hip_code_object(const std::string& content, hip_compile_options options)
......
...@@ -61,13 +61,25 @@ struct mlir_conv ...@@ -61,13 +61,25 @@ struct mlir_conv
MIGRAPHX_REGISTER_OP(mlir_conv); MIGRAPHX_REGISTER_OP(mlir_conv);
namespace { namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
if(ins->name() != "convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
return true;
}
struct find_conv_pointwise struct find_conv_pointwise
{ {
// Find a convolution followed by a pointwise operation. // Find a convolution followed by a pointwise operation.
auto matcher() const auto matcher() const
{ {
auto convolution = auto convolution =
match::skip(match::name("contiguous"))(match::name("convolution").bind("convolution")); match::skip(match::name("contiguous"))(is_mlir_conv().bind("convolution"));
return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x")));
} }
......
...@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu> ...@@ -259,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
}; };
MIGRAPHX_REGISTER_OP(hip_add_relu) MIGRAPHX_REGISTER_OP(hip_add_relu)
struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid> struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{ {
}; };
MIGRAPHX_REGISTER_OP(hip_add_sigmoid) MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
......
...@@ -176,8 +176,13 @@ void gemm_impl(context& ctx, ...@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
{ {
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as // the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -52,7 +52,7 @@ __global__ void ${kernel}(${params}) ...@@ -52,7 +52,7 @@ __global__ void ${kernel}(${params})
{ {
auto idx = make_index(); auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...); ${layernorm}<${axis}>(${post}, ${eps}, xs...);
}); });
} }
...@@ -90,6 +90,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -90,6 +90,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel"); options.kernel_name = v.get("kernel", "layernorm_kernel");
auto eps = v.get("epsilon", 1e-12f);
auto src = interpolate_string(layernorm_kernel, auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
...@@ -99,7 +100,8 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -99,7 +100,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"post", v.get("post", std::string{"op::id{}"})}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})}, {"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}}); {"axis", to_string(axis)},
{"eps", to_string(eps)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -37,45 +37,91 @@ namespace migraphx { ...@@ -37,45 +37,91 @@ namespace migraphx {
template <class U> \ template <class U> \
constexpr array& operator op(const array<U, N>& x) \ constexpr array& operator op(const array<U, N>& x) \
{ \ { \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each(*this, x)([](auto& sy, auto sx) { sy op sx; }); \
d[i] op x[i]; \
return *this; \ return *this; \
} \ } \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
constexpr array& operator op(const U& x) \ constexpr array& operator op(const U& x) \
{ \ { \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each (*this)([&](auto& sy) { sy op x; }); \
d[i] op x; \
return *this; \ return *this; \
} \ } \
template <class U> \ template <class U> \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \ friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
{ \ { \
array<decltype(T {} binary_op U{}), N> z{}; \ array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each(z, x, y)( \
z[i] = x[i] binary_op y[i]; \ [&](auto& sz, auto sx, auto sy) { sz = sx binary_op sy; }); \
return z; \ return z; \
} \ } \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const array& x, const U& y) \ friend constexpr auto operator binary_op(const array& x, const U& y) \
{ \ { \
array<decltype(T {} binary_op U{}), N> z{}; \ array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each(z, x)([&](auto& sz, auto sx) { sz = sx binary_op y; }); \
z[i] = x[i] binary_op y; \
return z; \ return z; \
} \ } \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const U& x, const array& y) \ friend constexpr auto operator binary_op(const U& x, const array& y) \
{ \ { \
array<decltype(T {} binary_op U{}), N> z{}; \ array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each(z, y)([&](auto& sz, auto sy) { sz = x binary_op sy; }); \
z[i] = x binary_op y[i]; \
return z; \ return z; \
} }
namespace array_detail {
template <class T>
constexpr auto is_vectorizable()
{
return not is_same<T, bool>{} and (is_fundamental<T>{} or is_same<T, half>{});
}
template <class T>
__device__ auto& array2vec(T& x)
{
using value_type = typename T::value_type;
constexpr auto size = decltype(x.size()){};
using type = vec<value_type, size>;
if constexpr(is_const<T>{})
return reinterpret_cast<const type&>(x);
else
return reinterpret_cast<type&>(x);
}
template <class T, class... Ts>
constexpr auto array_for_each(T& x, Ts&... xs)
{
MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
return [&](auto f) {
constexpr auto size = decltype(x.size()){};
if constexpr((is_vectorizable<typename T::value_type>() or
(is_vectorizable<typename Ts::value_type>() or ...)) and
size <= 8 and size > 1 and (size % 2 == 0))
{
if(__builtin_is_constant_evaluated())
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
else
{
using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
}
}
else
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
};
}
} // namespace array_detail
template <class T, index_int N> template <class T, index_int N>
struct array struct array
{ {
using value_type = T;
T d[N]; T d[N];
constexpr T& operator[](index_int i) constexpr T& operator[](index_int i)
{ {
...@@ -108,18 +154,13 @@ struct array ...@@ -108,18 +154,13 @@ struct array
constexpr T dot(const array& x) const constexpr T dot(const array& x) const
{ {
T result = 0; auto r = x * (*this);
for(index_int i = 0; i < N; i++) return r.reduce([](auto a, auto b) { return a + b; }, 0);
result += x[i] * d[i];
return result;
} }
constexpr T product() const constexpr T product() const
{ {
T result = 1; return reduce([](auto x, auto y) { return x * y; }, 1);
for(index_int i = 0; i < N; i++)
result *= d[i];
return result;
} }
constexpr T single(index_int width = 100) const constexpr T single(index_int width = 100) const
...@@ -134,6 +175,24 @@ struct array ...@@ -134,6 +175,24 @@ struct array
return result; return result;
} }
template <class F>
constexpr auto apply(F f) const
{
array<decltype(f(d[0])), N> result;
for(index_int i = 0; i < N; i++)
result[i] = f(d[i]);
return result;
}
template <class F>
constexpr auto reduce(F f, T init) const
{
T result = init;
for(index_int i = 0; i < N; i++)
result = f(result, d[i]);
return result;
}
MIGRAPHX_DEVICE_ARRAY_OP(+=, +) MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -) MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *) MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
...@@ -201,6 +260,11 @@ struct array ...@@ -201,6 +260,11 @@ struct array
} }
}; };
template <class T, class... Ts>
constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {x, static_cast<T>(xs)...};
}
template <class T, T... Xs> template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)> struct integral_const_array : array<T, sizeof...(Xs)>
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#ifndef MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
#define MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
namespace migraphx {
template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto strides = get_shape_c<Output>{}.strides;
constexpr auto offset = return_c([] {
constexpr auto output_shape = get_shape_c<Output>{};
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
}
template <index_int Axis, class Input>
constexpr auto concat_ends(Input)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>;
}
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
...@@ -28,9 +28,60 @@ ...@@ -28,9 +28,60 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
#endif
inline __device__ __attribute__((const)) index_int compute_global_size()
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
// This actualy works even when global is not divisible by local size.
// This doesnt actually do a multiplicatiosn. Instead it calls a device
// function to get the global size, which is why it works.
return blockDim.x * gridDim.x; // NOLINT
#endif
}
// We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the
// size for the last group.
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
const auto nlocal = MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#else
const auto ngroup = gridDim.x; // NOLINT
#endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
struct index struct index
{ {
index_int global = 0; index_int global = 0;
...@@ -38,20 +89,44 @@ struct index ...@@ -38,20 +89,44 @@ struct index
index_int group = 0; index_int group = 0;
#ifdef MIGRAPHX_NGLOBAL #ifdef MIGRAPHX_NGLOBAL
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; } constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const
{
static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
return {};
}
#else #else
__device__ index_int nglobal() const __device__ index_int nglobal() const
{ {
return blockDim.x * gridDim.x; // NOLINT MIGRAPHX_ASSERT(compute_global_size() > 0);
return compute_global_size(); // NOLINT
} }
#endif #endif
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_HAS_CONST_LOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; } constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const
{
static_assert(MIGRAPHX_NLOCAL > 0, "Local size must be greater than 0");
return {};
}
#else #else
__device__ index_int nlocal() const __device__ index_int nlocal() const
{ {
return blockDim.x; // NOLINT #ifdef MIGRAPHX_NGROUP
static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1),
"Local size should be const");
#endif
MIGRAPHX_ASSERT(compute_local_size() > 0);
return compute_local_size(); // NOLINT
}
#endif
#ifdef MIGRAPHX_NLOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> max_nlocal() const { return {}; }
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
} }
#endif #endif
template <class N, class Stride> template <class N, class Stride>
...@@ -63,6 +138,7 @@ struct index ...@@ -63,6 +138,7 @@ struct index
template <class F, class N, class Stride> template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1) max_stride_iterations(n, stride) == 1)
{ {
......
...@@ -29,6 +29,12 @@ ...@@ -29,6 +29,12 @@
namespace migraphx { namespace migraphx {
template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
return a.apply([&](auto x) { return vec_reduce(x, op); });
}
template <index_int Axis, template <index_int Axis,
class F, class F,
class BinOp, class BinOp,
...@@ -37,46 +43,46 @@ template <index_int Axis, ...@@ -37,46 +43,46 @@ template <index_int Axis,
class Input2, class Input2,
class... Inputs> class... Inputs>
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) { auto means =
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements}; auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2); })(input1, input2);
};
// mean(x) auto mean_x = means[0];
auto mean_x = mean(op); auto mean_x2 = means[1];
// mean(m ^ 2) auto variance = mean_x2 - (mean_x * mean_x);
auto mean_m2 = mean([&](auto x1, auto x2) { value_type eps_val = eps; // implicit conversion for eps
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x; auto x = op(x1, x2);
// m * rsqrt(mean(m ^ 2) + 1e-12) auto m = x - mean_x;
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...); })(output, input1, input2, inputs...);
}); });
} }
template <index_int Axis, class F, class Output, class Input, class... Inputs> template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs) __device__ void layernorm(F compute, float eps, Output output, Input input, Inputs... inputs)
{ {
generic_binary_layernorm<Axis>( generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...); compute, [](auto x, auto) { return x; }, eps, output, input, input, inputs...);
} }
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs> template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void __device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs) add_layernorm(F compute, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
generic_binary_layernorm<Axis>( generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...); compute, [](auto x1, auto x2) { return x1 + x2; }, eps, output, input1, input2, inputs...);
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max) ...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min) MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul) MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return y; return y;
} }
#else #else
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x; buffer[idx.local] = x;
...@@ -201,12 +202,9 @@ struct block ...@@ -201,12 +202,9 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
op, return vec_reduce(read(x[j], xs[j]...), op);
init, });
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
......
...@@ -151,7 +151,6 @@ struct miopen_apply ...@@ -151,7 +151,6 @@ struct miopen_apply
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("clip"); add_extend_op("clip");
add_extend_op("concat");
add_extend_op("convert"); add_extend_op("convert");
add_extend_op("elu"); add_extend_op("elu");
add_extend_op("gather"); add_extend_op("gather");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment