Commit dd26f1aa authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents 4e3d06ab 4a3e493c
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
......
......@@ -56,7 +56,8 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
c = beta * c;
// This is a simple optimization to avoid
// computing A * B when alpha is 0.0
if(alpha != 0.0)
{
c = c + alpha * a * b;
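For reference, a minimal standalone sketch of the shortcut this hunk adds (illustrative only; naive_gemm and its arguments are hypothetical, not MIGraphX code): once beta * C has been applied, the A * B accumulation can be skipped entirely when alpha is zero.

#include <cstddef>
#include <vector>

// Hypothetical naive GEMM illustrating the alpha == 0 shortcut.
void naive_gemm(std::vector<float>& c,
                const std::vector<float>& a,
                const std::vector<float>& b,
                std::size_t m, std::size_t n, std::size_t k,
                float alpha, float beta)
{
    for(std::size_t i = 0; i < m * n; i++)
        c[i] *= beta; // C = beta * C is always applied
    if(alpha == 0.0f)
        return; // skip the O(m*n*k) accumulation entirely
    for(std::size_t i = 0; i < m; i++)
        for(std::size_t j = 0; j < n; j++)
            for(std::size_t p = 0; p < k; p++)
                c[i * n + j] += alpha * a[i * k + p] * b[p * n + j];
}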
......@@ -116,10 +117,11 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
auto lens = cmat.get_shape().lens();
std::size_t num_matrices = std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
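The condition behind the new batch_mul flag reduces to counting how many 2-D matrices are stacked in the A tensor and checking whether that count is 1. A minimal sketch of that computation, assuming the last two dimensions are the matrix dimensions (batch_count is a hypothetical helper, not part of this commit):

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Number of stacked matrices in an N-D tensor: product of all dims except the last two.
std::size_t batch_count(const std::vector<std::size_t>& lens)
{
    assert(lens.size() >= 2);
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}

// batch_count({3, 2, 4, 5}) == 6 (six 4x5 matrices); batch_count({4, 5}) == 1 (plain GEMM path).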
......
......@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c) + epsilon) > 0);
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
std::sqrt(variance(c) + epsilon) +
bias(c);
assert((variance[c] + epsilon) > 0);
result(n, c, h, w) = gamma[c] * (buffer(n, c, h, w) - mean[c]) /
std::sqrt(variance[c] + epsilon) +
bias[c];
});
});
}
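The expression in this hunk is the standard batch-norm inference formula applied per channel c. A scalar sketch of the same arithmetic (illustrative only; the real kernel reads gamma, bias, mean and variance per channel):

#include <cmath>

// y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
// epsilon keeps the denominator positive when the variance is very small,
// which is what the assert above is guarding.
float batch_norm_inference(float x, float gamma, float bias, float mean, float variance, float epsilon)
{
    return gamma * (x - mean) / std::sqrt(variance + epsilon) + bias;
}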
......@@ -369,7 +369,15 @@ struct cpu_gemm
{
op::dot op;
std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted();
}
return op.compute_shape(inputs);
}
void fill_result(argument& result, argument& c) const
{
......@@ -429,7 +437,9 @@ struct cpu_gemm
}
else
{
fill_result(result, args[2]);
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, args[0], args[1], op.alpha, op.beta);
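The visit_all/std::copy above replaces the old fill_result helper: since migemm accumulates into its output as result = alpha * A * B + beta * result, the C operand has to be copied into the result buffer first. A simplified sketch of that two-step pattern (hypothetical names and plain float buffers instead of MIGraphX arguments):

#include <algorithm>
#include <vector>

// Assumes result and c have identical shapes (the not_broadcasted check in
// compute_shape above rejects broadcast C inputs, so a flat copy is valid).
void gemm_with_c(std::vector<float>& result, const std::vector<float>& c /*, A, B, alpha, beta */)
{
    std::copy(c.begin(), c.end(), result.begin()); // step 1: seed the output with C
    // step 2: migemm(result, a, b, alpha, beta);  // accumulates alpha*A*B + beta*result
}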
......@@ -437,33 +447,8 @@ struct cpu_gemm
return result;
}
// 2 input cases
// if the first argument is 1-dim, prepend a 1 at the beginning
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
shape::type_t t = output_shape.type();
if(a_lens.size() == 1)
{
a_lens.insert(a_lens.begin(), 1);
out_lens.push_back(1);
if(out_lens.size() > 1)
{
std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1));
}
}
if(b_lens.size() == 1)
{
b_lens.push_back(1);
out_lens.push_back(1);
}
migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
0.0f);
// 2 input arguments
migemm(result, args[0], args[1], op.alpha, 0.0f);
return result;
}
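The deleted branch promoted 1-D operands to matrices before calling migemm. For reference, a sketch of that promotion logic (hypothetical helper name, not part of this commit):

#include <cstddef>
#include <vector>

// A 1-D left operand of length k becomes a 1 x k row vector; a 1-D right operand
// becomes a k x 1 column vector, so vector/matrix products reduce to a plain GEMM.
std::vector<std::size_t> as_matrix_lens(std::vector<std::size_t> lens, bool is_left_operand)
{
    if(lens.size() == 1)
    {
        if(is_left_operand)
            lens.insert(lens.begin(), 1); // {k} -> {1, k}
        else
            lens.push_back(1);            // {k} -> {k, 1}
    }
    return lens;
}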
......
......@@ -2,7 +2,6 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......
......@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.op.group > 1)
return false;
if(conv.op.padding_mode != op::padding_mode_t::default_)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false;
auto op = conv.op;
......@@ -251,6 +253,12 @@ struct miopen_conv_bias
fusion::op_t conv;
fusion::op_t bias;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::convolution::reflect(self.op, f);
}
miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
: op(c), f(input)
{
......@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
fusion::op_t bias;
fusion::op_t relu;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::convolution::reflect(self.op, f);
}
miopen_conv_bias_relu(op::convolution c,
const shape& input,
const shape& weights,
......
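The reflect functions added to miopen_conv_bias and miopen_conv_bias_relu expose the wrapped convolution's fields to MIGraphX's generic reflection utilities (used for things like printing and comparing operators). A toy sketch of the idiom (illustrative only, not the MIGraphX machinery):

#include <cstddef>
#include <iostream>
#include <string>

struct toy_convolution
{
    std::size_t group   = 1;
    std::size_t padding = 0;

    // Expose each field to a visitor together with its name.
    template <class Self, class F>
    static void reflect(Self& self, F f)
    {
        f(self.group, "group");
        f(self.padding, "padding");
    }
};

int main()
{
    toy_convolution conv;
    // A generic consumer: print every reflected field without knowing the struct's layout.
    toy_convolution::reflect(conv, [](const auto& field, const std::string& name) {
        std::cout << name << " = " << field << "\n";
    });
}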
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <class... Ts>
void generic_rocblas_scal(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
{
rocblas_sscal(std::forward<Ts>(xs)...);
return rocblas_sscal(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_scal(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
{
rocblas_dscal(std::forward<Ts>(xs)...);
return rocblas_dscal(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_scal(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_scal(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
rocblas_status generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
{
rocblas_haxpy(std::forward<Ts>(xs)...);
return rocblas_haxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
{
rocblas_saxpy(std::forward<Ts>(xs)...);
return rocblas_saxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
{
rocblas_daxpy(std::forward<Ts>(xs)...);
return rocblas_daxpy(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_axpy(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_dot(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
rocblas_sdot(std::forward<Ts>(xs)...);
return rocblas_sdot(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_dot(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{
rocblas_ddot(std::forward<Ts>(xs)...);
return rocblas_ddot(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_dot(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
rocblas_sgemv(std::forward<Ts>(xs)...);
return rocblas_sgemv(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
rocblas_dgemv(std::forward<Ts>(xs)...);
return rocblas_dgemv(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
return rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
return rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
return rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm(std::forward<Ts>(xs)...);
return rocblas_sgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm(std::forward<Ts>(xs)...);
return rocblas_dgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm(std::forward<Ts>(xs)...);
return rocblas_hgemm(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
}
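Every wrapper in this file now returns the rocblas_status from the underlying rocBLAS call instead of discarding it, so callers can check for failures. A self-contained sketch of the tag-dispatch pattern these overloads follow (all names here are stand-ins, not rocBLAS or MIGraphX symbols):

#include <stdexcept>
#include <utility>

enum class status { success, failure };   // stand-in for rocblas_status
template <class T> struct as_type {};     // stand-in for shape::as<T>

// Stand-in for a typed BLAS entry point such as rocblas_sscal.
status scal_float(int n, const float* alpha, float* x, int incx)
{
    for(int i = 0; i < n; i++)
        x[i * incx] *= *alpha;
    return status::success;
}

// Tag-dispatched wrapper: forwards its arguments unchanged and now also forwards the status.
template <class... Ts>
status generic_scal(as_type<float>, Ts&&... xs)
{
    return scal_float(std::forward<Ts>(xs)...);
}

// Fallback for element types the library does not support.
template <class T, class... Ts>
status generic_scal(as_type<T>, Ts&&...)
{
    throw std::runtime_error("generic_scal: type unsupported");
}

// Usage sketch: a caller can now fail loudly instead of silently ignoring errors.
// if(generic_scal(as_type<float>{}, n, &alpha, x, 1) != status::success) { /* handle error */ }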
......@@ -168,198 +169,9 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> orig_inputs(inputs.begin(), inputs.begin() + inputs.size() - 1);
return op.compute_shape(orig_inputs);
}
void miopen_gemm::fill_result(const shape& output_shape,
const argument& result,
const argument& c) const
{
auto out_lens = output_shape.lens();
auto c_lens = c.get_shape().lens();
auto type_size = output_shape.type_size();
if(output_shape == c.get_shape())
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpy(
to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
});
}
else if(c.single())
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
return to_rocblas_type(as.from(arg.data() + offset_byte));
};
for(std::size_t i = 0; i < output_shape.elements(); ++i)
{
hipMemcpy(to_pointer(result, i * type_size),
to_pointer(c),
c.get_shape().bytes(),
hipMemcpyDeviceToDevice);
}
});
}
else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
{
auto m = out_lens[0];
auto n = out_lens[1];
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t i = 0; i < m; ++i)
{
hipMemcpy(to_pointer(result, i * n * type_size),
to_pointer(c),
c.get_shape().bytes(),
hipMemcpyDeviceToDevice);
}
});
}
// case of c_lens.size() == 2 && c_lens[0] == out_lens[0]
else
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t i = 0; i < output_shape.elements(); ++i)
{
hipMemcpy(to_pointer(result, i * type_size),
to_pointer(c, i / out_lens[1] * type_size),
type_size,
hipMemcpyDeviceToDevice);
}
});
}
}
argument miopen_gemm::batch_matmul(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
auto an_dim = a_lens.size();
auto bn_dim = b_lens.size();
auto outn_dim = out_lens.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2];
rocblas_int ldb = args[1].get_shape().strides()[transb ? bn_dim - 1 : bn_dim - 2];
rocblas_int ldc = args[2].get_shape().strides()[outn_dim - 2];
rocblas_int m = out_lens[outn_dim - 2];
rocblas_int n = out_lens[outn_dim - 1];
rocblas_int k = a_lens[an_dim - 1];
float beta = 0.0f;
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
{
std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
a_batch_lens.end(),
std::size_t{1},
std::multiplies<std::size_t>());
std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
b_batch_lens.end(),
std::size_t{1},
std::multiplies<std::size_t>());
std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k;
rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n;
rocblas_int stride_c = m * n;
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
stride_b,
to_pointer(args[0]),
lda,
stride_a,
&beta_r,
to_pointer(args[2]),
ldc,
stride_c,
num_matrices);
});
}
else
{
std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.begin() + outn_dim - 2);
shape::type_t t = output_shape.type();
shape a_batch_shape{t, a_batch_lens};
shape b_batch_shape{t, b_batch_lens};
shape out_batch_shape{t, out_batch_lens};
std::size_t a_len_diff = outn_dim - an_dim;
std::size_t b_len_diff = outn_dim - bn_dim;
shape_for_each(out_batch_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
auto type_size = output_shape.type_size();
std::vector<std::size_t> a_idx(a_batch_lens.size());
std::vector<std::size_t> b_idx(b_batch_lens.size());
std::transform(out_idx.begin() + a_len_diff,
out_idx.end(),
a_batch_lens.begin(),
a_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(out_idx.begin() + b_len_diff,
out_idx.end(),
b_batch_lens.begin(),
b_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1], k * n * b_ind * type_size),
ldb,
to_pointer(args[0], m * k * a_ind * type_size),
lda,
&beta_r,
to_pointer(args[2], m * n * out_ind * type_size),
ldc);
});
});
}
return args[2];
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
}
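In the rewritten compute_shape, the last element of inputs is the shape of the preallocated output buffer (as the surrounding compute() suggests, where args.size() == 4 for three logical inputs), so it is stripped before the remaining shapes are checked for a standard, non-broadcast layout. A simplified sketch of that flow (generic placeholder types; not the MIGraphX check_shapes API):

#include <vector>

template <class Shape, class Op, class NotBroadcast>
Shape gemm_compute_shape(const std::vector<Shape>& inputs, const Op& op, NotBroadcast not_broadcasted)
{
    // Drop the trailing output-buffer shape before validation.
    std::vector<Shape> args(inputs.begin(), inputs.end() - 1);
    for(const auto& s : args)
        not_broadcasted(s); // expected to throw if a shape has broadcast (stride-0) dimensions
    return op.compute_shape(args);
}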
argument miopen_gemm::compute(context& ctx,
......@@ -367,149 +179,60 @@ argument miopen_gemm::compute(context& ctx,
const std::vector<argument>& args) const
{
bool is_3inputs = (args.size() == 4);
float beta = 0.0f;
if(is_3inputs)
{
fill_result(output_shape, args[3], args[2]);
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[3].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
k * n,
to_pointer(args[0]),
lda,
m * k,
&beta_r,
to_pointer(args[3]),
ldc,
m * n,
num_matrices);
});
return args[3];
}
// 2 input argument cases
// vector inner product
if(output_shape.elements() == 1)
{
assert(args[0].get_shape().elements() == args[1].get_shape().elements());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as,
ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(),
to_pointer(args[0]),
1,
to_pointer(args[1]),
1,
to_pointer(args[2]));
generic_rocblas_scal(
as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
hipMemcpyAsync(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice,
ctx.get_stream().get());
});
beta = op.beta;
}
// matrix * vector
else if(args[1].get_shape().lens().size() == 1)
{
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
std::size_t dim_0 = a_lens.size() - 2;
std::size_t dim_1 = a_lens.size() - 1;
bool transa = args[0].get_shape().transposed();
bool transb = false;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = 1;
rocblas_int ldc = 1;
rocblas_int m = a_lens[dim_0];
rocblas_int n = 1;
rocblas_int k = a_lens[dim_1];
float beta = 0.0f;
assert(a_lens.back() == args[1].get_shape().elements());
std::size_t batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
0,
to_pointer(args[0]),
lda,
m * k,
&beta_r,
to_pointer(args[2]),
ldc,
m * n,
batch_num);
});
}
// vector * matrix
else if(args[0].get_shape().lens().size() == 1)
{
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
std::size_t dim_0 = b_lens.size() - 2;
std::size_t dim_1 = b_lens.size() - 1;
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
bool transa = false;
rocblas_int lda = a_lens[0];
rocblas_int ldb = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
rocblas_int ldc = b_lens[dim_1];
rocblas_int m = 1;
rocblas_int n = args[1].get_shape().lens()[dim_1];
rocblas_int k = a_lens[0];
float beta = 0.0f;
assert(b_lens[dim_0] == args[0].get_shape().elements());
std::size_t batch_num = std::accumulate(
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if(num_matrices == 1)
{
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc);
}
else
{
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
......@@ -524,21 +247,16 @@ argument miopen_gemm::compute(context& ctx,
k * n,
to_pointer(args[0]),
lda,
0,
m * k,
&beta_r,
to_pointer(args[2]),
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc,
m * n,
batch_num);
});
}
// (batch) matrix multiplication
else
{
batch_matmul(ctx, output_shape, args);
}
num_matrices);
}
});
return args[2];
return (is_3inputs ? args[3] : args[2]);
}
} // namespace gpu
......
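One detail worth noting in the rewritten compute: rocBLAS assumes column-major storage while MIGraphX tensors are row-major, so the calls above compute C^T = B^T * A^T, which is why n is passed before m and args[1] before args[0]. A self-contained sketch of why that works (a plain C++ stand-in for the rocBLAS kernel, not rocBLAS itself):

#include <cassert>
#include <vector>

// Column-major GEMM: C (m x n, ldc >= m) = alpha * A (m x k) * B (k x n) + beta * C.
void gemm_col_major(int m, int n, int k, float alpha, const float* a, int lda,
                    const float* b, int ldb, float beta, float* c, int ldc)
{
    for(int j = 0; j < n; j++)
        for(int i = 0; i < m; i++)
        {
            float acc = 0.0f;
            for(int p = 0; p < k; p++)
                acc += a[i + p * lda] * b[p + j * ldb];
            c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
        }
}

int main()
{
    int m = 2, n = 3, k = 4;
    std::vector<float> a(m * k, 1.0f), b(k * n, 2.0f), c(m * n, 0.0f); // row-major buffers
    // Row-major C = A*B is the column-major product C^T = B^T * A^T:
    // swap the (m, n) extents and pass B before A.
    gemm_col_major(n, m, k, 1.0f, b.data(), n, a.data(), k, 0.0f, c.data(), n);
    assert(c[0] == 8.0f); // each entry is the sum of k products 1 * 2
}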
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/batch_norm.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/concat.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTIGUOUS_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
argument allocate_gpu(const shape& s, bool host = false);
argument to_gpu(const argument& arg, bool host = false);
......
......@@ -4,7 +4,7 @@
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
......
......@@ -2,7 +2,9 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraphx/manage_ptr.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <miopen/miopen.h>
#include <migraphx/config.hpp>
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_POOLING_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/softmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
......
......@@ -17,20 +17,27 @@
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
std::vector<pass> target::get_passes(migraphx::context& gctx) const
{
auto& ctx = any_cast<context>(gctx);
// clang-format off
return
{
dead_code_elimination{},
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
......@@ -53,13 +60,14 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
dead_code_elimination{},
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
check_context<context>{},
dead_code_elimination{}
dead_code_elimination{},
eliminate_identity{}
};
// clang-format on
}
......
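The schedule pass is now gated on the MIGRAPHX_ENABLE_SCHEDULE_PASS environment variable through the declared env var and enabled(...). As a general illustration of that pattern (hypothetical helper; MIGraphX's own MIGRAPHX_DECLARE_ENV_VAR/enabled machinery is not reproduced here):

#include <cstdlib>
#include <string>

// Read a boolean-style environment flag; unset, "0", "false" and "off" count as disabled.
bool env_flag_enabled(const char* name)
{
    const char* v = std::getenv(name);
    if(v == nullptr)
        return false;
    std::string s(v);
    return !(s.empty() || s == "0" || s == "false" || s == "off");
}

// Usage sketch: the flag is passed into the pass object so a disabled pass can make
// itself a no-op instead of being dropped from the pipeline, e.g.
//   schedule{gpu::schedule_model{n}, env_flag_enabled("MIGRAPHX_ENABLE_SCHEDULE_PASS")}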