Commit 635788d1 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into stridedslice_op

parents 9487dd3a af00eea8
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs) rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_sscal(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs) rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_dscal(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_scal(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
{
return rocblas_haxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
{
return rocblas_saxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
{
return rocblas_daxpy(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
return rocblas_sdot(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{
return rocblas_ddot(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemv(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemv(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs) rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
{ {
rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...); return rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
{
return rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
} }
template <class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...) rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas"); MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm(std::forward<Ts>(xs)...); return rocblas_sgemm(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm(std::forward<Ts>(xs)...); return rocblas_dgemm(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs) rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{ {
rocblas_hgemm(std::forward<Ts>(xs)...); return rocblas_hgemm(std::forward<Ts>(xs)...);
} }
template <class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...) rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas"); MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
} }
...@@ -90,20 +169,40 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal ...@@ -90,20 +169,40 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
return op.compute_shape({inputs.at(0), inputs.at(1)}); check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
} }
argument miopen_gemm::compute(context& ctx, argument miopen_gemm::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
float alpha = 1.0f; bool is_3inputs = (args.size() == 4);
float beta = 0.0f; float beta = 0.0f;
if(is_3inputs)
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpyAsync(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice,
ctx.get_stream().get());
});
beta = op.beta;
}
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
std::size_t n_dims = args[0].get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
...@@ -111,13 +210,31 @@ argument miopen_gemm::compute(context& ctx, ...@@ -111,13 +210,31 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto batch_num = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(as, if(num_matrices == 1)
{
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc);
}
else
{
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -132,14 +249,14 @@ argument miopen_gemm::compute(context& ctx, ...@@ -132,14 +249,14 @@ argument miopen_gemm::compute(context& ctx,
lda, lda,
m * k, m * k,
&beta_r, &beta_r,
to_pointer(args[2]), (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc, ldc,
m * n, m * n,
batch_num); num_matrices);
}
}); });
return args[2]; return (is_3inputs ? args[3] : args[2]);
} }
} // namespace gpu } // namespace gpu
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/batch_norm.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP #define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/concat.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTIGUOUS_HPP #define MIGRAPHX_GUARD_RTGLIB_CONTIGUOUS_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/contiguous.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP #define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP #define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/dot.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct context;
argument allocate_gpu(const shape& s, bool host = false); argument allocate_gpu(const shape& s, bool host = false);
argument to_gpu(const argument& arg, bool host = false); argument to_gpu(const argument& arg, bool host = false);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <miopen/miopen.h> #include <miopen/miopen.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP #define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/pad.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment