Commit dd26f1aa authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents 4e3d06ab 4a3e493c
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/as_shape.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
 #include <unordered_set>
...
@@ -56,7 +56,8 @@ void migemm_impl(tensor_view<T> cmat,
     visit_mat(bmat, [&](const auto& b) {
         auto c = make_mat(cmat);
         c      = beta * c;
+        // This is a simple optimization to avoid
+        // computing A * B if alpha is 0.0
         if(alpha != 0.0)
         {
             c = c + alpha * a * b;
@@ -116,10 +117,11 @@ template <class T>
 void migemm_impl(
     tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
 {
-    auto lens                = cmat.get_shape().lens();
-    std::size_t num_matrices = std::accumulate(
-        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-    if(num_matrices == 1)
+    auto lens      = amat.get_shape().lens();
+    bool batch_mul =
+        std::accumulate(
+            lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
+    if(batch_mul)
     {
         migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
     }
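Note on the hunk above: the refactored check multiplies together every dimension except the trailing two and takes the plain 2-D GEMM path when that product is 1. A minimal standalone sketch of the same computation (illustrative, not MIGraphX code):

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Product of all but the trailing two dimensions; 1 means a single matrix,
// so a plain 2-D GEMM can be used instead of the batched path.
std::size_t batch_count(const std::vector<std::size_t>& lens)
{
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
// batch_count({3, 4, 2, 5}) == 12, batch_count({2, 5}) == 1
```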
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
             par_dfor(num_batch, num_channels, image_height, image_width)(
                 [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
-                    assert((variance(c) + epsilon) > 0);
-                    result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
-                                             std::sqrt(variance(c) + epsilon) +
-                                         bias(c);
+                    assert((variance[c] + epsilon) > 0);
+                    result(n, c, h, w) = gamma[c] * (buffer(n, c, h, w) - mean[c]) /
+                                             std::sqrt(variance[c] + epsilon) +
+                                         bias[c];
                 });
         });
     }
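For reference, the kernel above applies the standard batch-norm inference formula, with gamma, bias, mean, and variance indexed per channel c (hence the switch to operator[]); a scalar sketch of the per-element computation:

```cpp
#include <cmath>

// y = gamma * (x - mean) / sqrt(variance + epsilon) + bias, per channel c
float bn_inference(float x, float gamma, float bias, float mean, float variance, float epsilon)
{
    return gamma * (x - mean) / std::sqrt(variance + epsilon) + bias;
}
```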
@@ -369,7 +369,15 @@ struct cpu_gemm
 {
     op::dot op;
     std::string name() const { return "cpu::dot"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(inputs.size() == 3)
+        {
+            auto c_shape = inputs.at(2);
+            check_shapes{{c_shape}}.not_broadcasted();
+        }
+        return op.compute_shape(inputs);
+    }
 
     void fill_result(argument& result, argument& c) const
     {
@@ -429,7 +437,9 @@ struct cpu_gemm
            }
            else
            {
-                fill_result(result, args[2]);
+                visit_all(result, args[2])([&](auto output, auto input) {
+                    std::copy(input.begin(), input.end(), output.begin());
+                });
            }
            migemm(result, args[0], args[1], op.alpha, op.beta);
@@ -437,33 +447,8 @@ struct cpu_gemm
            return result;
        }
 
-        // 2 input cases
-        // first argument is 1-dim, pre-pend 1 at beginning
-        auto a_lens     = args[0].get_shape().lens();
-        auto b_lens     = args[1].get_shape().lens();
-        auto out_lens   = output_shape.lens();
-        shape::type_t t = output_shape.type();
-        if(a_lens.size() == 1)
-        {
-            a_lens.insert(a_lens.begin(), 1);
-            out_lens.push_back(1);
-            if(out_lens.size() > 1)
-            {
-                std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1));
-            }
-        }
-        if(b_lens.size() == 1)
-        {
-            b_lens.push_back(1);
-            out_lens.push_back(1);
-        }
-        migemm({{t, out_lens}, result.data()},
-               {{t, a_lens}, args[0].data()},
-               {{t, b_lens}, args[1].data()},
-               op.alpha,
-               0.0f);
+        // 2 input arguments
+        migemm(result, args[0], args[1], op.alpha, 0.0f);
 
        return result;
    }
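The pattern in the three-input path above, copying C into the output buffer and then letting the GEMM apply beta, is the usual way to express D = alpha*A*B + beta*C with a BLAS-style GEMM that updates its output in place. A hedged sketch, with placeholder names rather than MIGraphX APIs:

```cpp
#include <algorithm>
#include <vector>

// Seed the output with C; a gemm that computes out = alpha*A*B + beta*out
// then yields alpha*A*B + beta*C. gemm_inplace() is a placeholder, not a real API.
void gemm_with_c(std::vector<float>& out, const std::vector<float>& c, float alpha, float beta)
{
    std::copy(c.begin(), c.end(), out.begin()); // out <- C
    // gemm_inplace(out, a, b, alpha, beta);    // out <- alpha*A*B + beta*out
}
```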
...
@@ -2,7 +2,6 @@
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/stringutils.hpp>
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
     if(conv.op.group > 1)
         return false;
+    if(conv.op.padding_mode != op::padding_mode_t::default_)
+        return false;
     if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
     auto op = conv.op;
@@ -251,6 +253,12 @@ struct miopen_conv_bias
     fusion::op_t conv;
     fusion::op_t bias;
 
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return op::convolution::reflect(self.op, f);
+    }
+
     miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
         : op(c), f(input)
     {
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
     fusion::op_t bias;
     fusion::op_t relu;
 
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return op::convolution::reflect(self.op, f);
+    }
+
     miopen_conv_bias_relu(op::convolution c,
                           const shape& input,
                           const shape& weights,
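The added reflect members forward to op::convolution::reflect, presumably so the fused operators expose the wrapped convolution's fields to MIGraphX's reflection-based utilities (printing, comparison, serialization). A generic sketch of the idiom, using made-up types rather than the real MIGraphX machinery:

```cpp
#include <iostream>

// A type lists its fields by invoking a visitor on each one; a wrapper can
// forward to the wrapped type so both present the same fields.
struct conv_settings
{
    int padding = 0;
    int stride  = 1;

    template <class Self, class F>
    static void reflect(Self& self, F f)
    {
        f(self.padding, "padding");
        f(self.stride, "stride");
    }
};

struct fused_conv
{
    conv_settings op;

    template <class Self, class F>
    static void reflect(Self& self, F f)
    {
        conv_settings::reflect(self.op, f); // forward, as in the diff above
    }
};

int main()
{
    fused_conv fc;
    fused_conv::reflect(fc, [](auto& field, const char* name) {
        std::cout << name << " = " << field << "\n";
    });
}
```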
...
 #include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/device/add.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 template <class... Ts>
-void generic_rocblas_scal(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
 {
-    rocblas_sscal(std::forward<Ts>(xs)...);
+    return rocblas_sscal(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_scal(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
 {
-    rocblas_dscal(std::forward<Ts>(xs)...);
+    return rocblas_dscal(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_scal(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_scal(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
 }
 template <class... Ts>
-void generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
+rocblas_status generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
 {
-    rocblas_haxpy(std::forward<Ts>(xs)...);
+    return rocblas_haxpy(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
 {
-    rocblas_saxpy(std::forward<Ts>(xs)...);
+    return rocblas_saxpy(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
 {
-    rocblas_daxpy(std::forward<Ts>(xs)...);
+    return rocblas_daxpy(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_axpy(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
 }
 template <class... Ts>
-void generic_rocblas_dot(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
 {
-    rocblas_sdot(std::forward<Ts>(xs)...);
+    return rocblas_sdot(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_dot(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
 {
-    rocblas_ddot(std::forward<Ts>(xs)...);
+    return rocblas_ddot(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_dot(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
 }
 template <class... Ts>
-void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
 {
-    rocblas_sgemv(std::forward<Ts>(xs)...);
+    return rocblas_sgemv(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
 {
-    rocblas_dgemv(std::forward<Ts>(xs)...);
+    return rocblas_dgemv(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_gemv(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
 }
 template <class... Ts>
-void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
 {
-    rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
+    return rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
 {
-    rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
+    return rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
+rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
 {
-    rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
+    return rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
 }
 template <class... Ts>
-void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
+rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
 {
-    rocblas_sgemm(std::forward<Ts>(xs)...);
+    return rocblas_sgemm(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
+rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
 {
-    rocblas_dgemm(std::forward<Ts>(xs)...);
+    return rocblas_dgemm(std::forward<Ts>(xs)...);
 }
 template <class... Ts>
-void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
+rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
 {
-    rocblas_hgemm(std::forward<Ts>(xs)...);
+    return rocblas_hgemm(std::forward<Ts>(xs)...);
 }
 template <class T, class... Ts>
-void generic_rocblas_gemm(shape::as<T>, Ts&&...)
+rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
 {
     MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
 }
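Two things are going on in the wrappers above: the shape::as<T> tag selects the correctly typed rocBLAS entry point through overload resolution, and the signatures now return rocblas_status so callers can inspect the result instead of it being dropped. A small self-contained sketch of the same tag-dispatch shape (illustrative only, no rocBLAS):

```cpp
#include <cstddef>
#include <iostream>

enum class blas_status { success, unsupported };

template <class T>
struct as_type {};

// Typed overload: stands in for the rocblas_s* call and now returns a status.
inline blas_status scale(as_type<float>, float* x, std::size_t n, float a)
{
    for(std::size_t i = 0; i < n; i++)
        x[i] *= a;
    return blas_status::success;
}

// Fallback for element types the backend does not support,
// analogous to the MIGRAPHX_THROW overloads above.
template <class T, class... Ts>
blas_status scale(as_type<T>, Ts&&...)
{
    return blas_status::unsupported;
}

int main()
{
    float v[3] = {1, 2, 3};
    auto st    = scale(as_type<float>{}, v, std::size_t{3}, 2.0f);
    std::cout << (st == blas_status::success) << " " << v[0] << "\n"; // prints: 1 2
}
```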
@@ -168,198 +169,9 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
 shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
-    std::vector<shape> orig_inputs(inputs.begin(), inputs.begin() + inputs.size() - 1);
-    return op.compute_shape(orig_inputs);
-}
-
-void miopen_gemm::fill_result(const shape& output_shape,
-                              const argument& result,
-                              const argument& c) const
-{
-    auto out_lens  = output_shape.lens();
-    auto c_lens    = c.get_shape().lens();
-    auto type_size = output_shape.type_size();
-    if(output_shape == c.get_shape())
-    {
-        output_shape.visit_type([&](auto as) {
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            hipMemcpy(
-                to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
-        });
-    }
-    else if(c.single())
-    {
-        output_shape.visit_type([&](auto as) {
-            auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
-                return to_rocblas_type(as.from(arg.data() + offset_byte));
-            };
-            for(std::size_t i = 0; i < output_shape.elements(); ++i)
-            {
-                hipMemcpy(to_pointer(result, i * type_size),
-                          to_pointer(c),
-                          c.get_shape().bytes(),
-                          hipMemcpyDeviceToDevice);
-            }
-        });
-    }
-    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
-    {
-        auto m = out_lens[0];
-        auto n = out_lens[1];
-        output_shape.visit_type([&](auto as) {
-            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-                return to_rocblas_type(as.from(arg.data() + offset));
-            };
-            for(std::size_t i = 0; i < m; ++i)
-            {
-                hipMemcpy(to_pointer(result, i * n * type_size),
-                          to_pointer(c),
-                          c.get_shape().bytes(),
-                          hipMemcpyDeviceToDevice);
-            }
-        });
-    }
-    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
-    else
-    {
-        output_shape.visit_type([&](auto as) {
-            auto to_pointer = [&](auto&& arg, std::size_t offset) {
-                return to_rocblas_type(as.from(arg.data() + offset));
-            };
-            for(std::size_t i = 0; i < output_shape.elements(); ++i)
-            {
-                hipMemcpy(to_pointer(result, i * type_size),
-                          to_pointer(c, i / out_lens[1] * type_size),
-                          type_size,
-                          hipMemcpyDeviceToDevice);
-            }
-        });
-    }
-}
-
-argument miopen_gemm::batch_matmul(context& ctx,
-                                   const shape& output_shape,
-                                   const std::vector<argument>& args) const
-{
-    bool transa   = args[0].get_shape().transposed();
-    bool transb   = args[1].get_shape().transposed();
-    auto a_lens   = args[0].get_shape().lens();
-    auto b_lens   = args[1].get_shape().lens();
-    auto out_lens = output_shape.lens();
-    auto an_dim   = a_lens.size();
-    auto bn_dim   = b_lens.size();
-    auto outn_dim = out_lens.size();
-    rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2];
-    rocblas_int ldb = args[1].get_shape().strides()[transb ? bn_dim - 1 : bn_dim - 2];
-    rocblas_int ldc = args[2].get_shape().strides()[outn_dim - 2];
-    rocblas_int m   = out_lens[outn_dim - 2];
-    rocblas_int n   = out_lens[outn_dim - 1];
-    rocblas_int k   = a_lens[an_dim - 1];
-    float beta      = 0.0f;
-    std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
-    std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
-    if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
-    {
-        std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
-                                                    a_batch_lens.end(),
-                                                    std::size_t{1},
-                                                    std::multiplies<std::size_t>());
-        std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
-                                                    b_batch_lens.end(),
-                                                    std::size_t{1},
-                                                    std::multiplies<std::size_t>());
-        std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
-        rocblas_int stride_a     = (numa_matrices == 1) ? 0 : m * k;
-        rocblas_int stride_b     = (numb_matrices == 1) ? 0 : k * n;
-        rocblas_int stride_c     = m * n;
-        output_shape.visit_type([&](auto as) {
-            auto alpha_r    = to_rocblas_type(as(op.alpha));
-            auto beta_r     = to_rocblas_type(as(beta));
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            generic_rocblas_batched_gemm(
-                as,
-                ctx.get_stream().get_rocblas(),
-                transb ? rocblas_operation_transpose : rocblas_operation_none,
-                transa ? rocblas_operation_transpose : rocblas_operation_none,
-                n,
-                m,
-                k,
-                &alpha_r,
-                to_pointer(args[1]),
-                ldb,
-                stride_b,
-                to_pointer(args[0]),
-                lda,
-                stride_a,
-                &beta_r,
-                to_pointer(args[2]),
-                ldc,
-                stride_c,
-                num_matrices);
-        });
-    }
-    else
-    {
-        std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.begin() + outn_dim - 2);
-        shape::type_t t = output_shape.type();
-        shape a_batch_shape{t, a_batch_lens};
-        shape b_batch_shape{t, b_batch_lens};
-        shape out_batch_shape{t, out_batch_lens};
-        std::size_t a_len_diff = outn_dim - an_dim;
-        std::size_t b_len_diff = outn_dim - bn_dim;
-        shape_for_each(out_batch_shape, [&](auto out_idx) {
-            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
-            auto type_size      = output_shape.type_size();
-            std::vector<std::size_t> a_idx(a_batch_lens.size());
-            std::vector<std::size_t> b_idx(b_batch_lens.size());
-            std::transform(out_idx.begin() + a_len_diff,
-                           out_idx.end(),
-                           a_batch_lens.begin(),
-                           a_idx.begin(),
-                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
-            std::transform(out_idx.begin() + b_len_diff,
-                           out_idx.end(),
-                           b_batch_lens.begin(),
-                           b_idx.begin(),
-                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
-            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
-            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
-            output_shape.visit_type([&](auto as) {
-                auto alpha_r    = to_rocblas_type(as(op.alpha));
-                auto beta_r     = to_rocblas_type(as(beta));
-                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-                    return to_rocblas_type(as.from(arg.data() + offset));
-                };
-                generic_rocblas_gemm(as,
-                                     ctx.get_stream().get_rocblas(),
-                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
-                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
-                                     n,
-                                     m,
-                                     k,
-                                     &alpha_r,
-                                     to_pointer(args[1], k * n * b_ind * type_size),
-                                     ldb,
-                                     to_pointer(args[0], m * k * a_ind * type_size),
-                                     lda,
-                                     &beta_r,
-                                     to_pointer(args[2], m * n * out_ind * type_size),
-                                     ldc);
-            });
-        });
-    }
-    return args[2];
+    std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
+    check_shapes{input_shapes}.not_broadcasted();
+    return op.compute_shape(input_shapes);
 }
 argument miopen_gemm::compute(context& ctx,
@@ -367,149 +179,60 @@ argument miopen_gemm::compute(context& ctx,
                               const std::vector<argument>& args) const
 {
     bool is_3inputs = (args.size() == 4);
+    float beta      = 0.0f;
     if(is_3inputs)
     {
-        fill_result(output_shape, args[3], args[2]);
-        output_shape.visit_type([&](auto as) {
-            auto n_dim   = output_shape.lens().size();
-            auto dim_1   = n_dim - 1;
-            auto dim_0   = n_dim - 2;
-            auto alpha_r = to_rocblas_type(as(op.alpha));
-            auto beta_r  = to_rocblas_type(as(op.beta));
-            bool transa  = args[0].get_shape().transposed();
-            bool transb  = args[1].get_shape().transposed();
-            rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-            rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-            rocblas_int ldc = args[3].get_shape().strides()[dim_0];
-            auto out_lens   = output_shape.lens();
-            rocblas_int m   = out_lens[dim_0];
-            rocblas_int n   = out_lens[dim_1];
-            rocblas_int k   = args[0].get_shape().lens()[dim_1];
-            auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
-                                                out_lens.rend(),
-                                                std::size_t{1},
-                                                std::multiplies<std::size_t>());
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            generic_rocblas_batched_gemm(
-                as,
-                ctx.get_stream().get_rocblas(),
-                transb ? rocblas_operation_transpose : rocblas_operation_none,
-                transa ? rocblas_operation_transpose : rocblas_operation_none,
-                n,
-                m,
-                k,
-                &alpha_r,
-                to_pointer(args[1]),
-                ldb,
-                k * n,
-                to_pointer(args[0]),
-                lda,
-                m * k,
-                &beta_r,
-                to_pointer(args[3]),
-                ldc,
-                m * n,
-                num_matrices);
-        });
-        return args[3];
-    }
-
-    // 2 input arguments cases
-    // vector inner product
-    if(output_shape.elements() == 1)
-    {
-        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
         output_shape.visit_type([&](auto as) {
-            auto alpha_r    = to_rocblas_type(as(op.alpha));
             auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            generic_rocblas_dot(as,
-                                ctx.get_stream().get_rocblas(),
-                                args[1].get_shape().elements(),
-                                to_pointer(args[0]),
-                                1,
-                                to_pointer(args[1]),
-                                1,
-                                to_pointer(args[2]));
-            generic_rocblas_scal(
-                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
+            hipMemcpyAsync(to_pointer(args[3]),
+                           to_pointer(args[2]),
+                           output_shape.bytes(),
+                           hipMemcpyDeviceToDevice,
+                           ctx.get_stream().get());
         });
+        beta = op.beta;
     }
-    // matrix * vector
-    else if(args[1].get_shape().lens().size() == 1)
-    {
-        auto a_lens = args[0].get_shape().lens();
-        auto b_lens = args[1].get_shape().lens();
-        std::size_t dim_0 = a_lens.size() - 2;
-        std::size_t dim_1 = a_lens.size() - 1;
-        bool transa = args[0].get_shape().transposed();
-        bool transb = false;
-        rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-        rocblas_int ldb = 1;
-        rocblas_int ldc = 1;
-        rocblas_int m   = a_lens[dim_0];
-        rocblas_int n   = 1;
-        rocblas_int k   = a_lens[dim_1];
-        float beta      = 0.0f;
-        assert(a_lens.back() == args[1].get_shape().elements());
-        std::size_t batch_num = std::accumulate(
-            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-        output_shape.visit_type([&](auto as) {
-            auto alpha_r    = to_rocblas_type(as(op.alpha));
-            auto beta_r     = to_rocblas_type(as(beta));
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-            generic_rocblas_batched_gemm(
-                as,
-                ctx.get_stream().get_rocblas(),
-                transb ? rocblas_operation_transpose : rocblas_operation_none,
-                transa ? rocblas_operation_transpose : rocblas_operation_none,
-                n,
-                m,
-                k,
-                &alpha_r,
-                to_pointer(args[1]),
-                ldb,
-                0,
-                to_pointer(args[0]),
-                lda,
-                m * k,
-                &beta_r,
-                to_pointer(args[2]),
-                ldc,
-                m * n,
-                batch_num);
-        });
-    }
-    // vector * matrix
-    else if(args[0].get_shape().lens().size() == 1)
-    {
-        auto a_lens = args[0].get_shape().lens();
-        auto b_lens = args[1].get_shape().lens();
-        std::size_t dim_0 = b_lens.size() - 2;
-        std::size_t dim_1 = b_lens.size() - 1;
+
+    auto a_lens = args[0].get_shape().lens();
+    auto b_lens = args[1].get_shape().lens();
+    output_shape.visit_type([&](auto as) {
+        auto n_dim   = output_shape.lens().size();
+        auto dim_1   = n_dim - 1;
+        auto dim_0   = n_dim - 2;
+        auto alpha_r = to_rocblas_type(as(op.alpha));
+        auto beta_r  = to_rocblas_type(as(beta));
+        bool transa  = args[0].get_shape().transposed();
         bool transb = args[1].get_shape().transposed();
-        bool transa = false;
-        rocblas_int lda = a_lens[0];
-        rocblas_int ldb = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
-        rocblas_int ldc = b_lens[dim_1];
-        rocblas_int m   = 1;
-        rocblas_int n   = args[1].get_shape().lens()[dim_1];
-        rocblas_int k   = a_lens[0];
-        float beta      = 0.0f;
-        assert(b_lens[dim_0] == args[0].get_shape().elements());
-        std::size_t batch_num = std::accumulate(
-            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-        output_shape.visit_type([&](auto as) {
-            auto alpha_r    = to_rocblas_type(as(op.alpha));
-            auto beta_r     = to_rocblas_type(as(beta));
-            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
+        rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
+        rocblas_int ldc = args[2].get_shape().strides()[dim_0];
+        auto out_lens   = output_shape.lens();
+        rocblas_int m   = out_lens[dim_0];
+        rocblas_int n   = out_lens[dim_1];
+        rocblas_int k   = args[0].get_shape().lens()[dim_1];
+        auto num_matrices = std::accumulate(
+            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
+        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        if(num_matrices == 1)
+        {
+            generic_rocblas_gemm(as,
+                                 ctx.get_stream().get_rocblas(),
+                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
+                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
+                                 n,
+                                 m,
+                                 k,
+                                 &alpha_r,
+                                 to_pointer(args[1]),
+                                 ldb,
+                                 to_pointer(args[0]),
+                                 lda,
+                                 &beta_r,
+                                 (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
+                                 ldc);
+        }
+        else
+        {
             generic_rocblas_batched_gemm(
                 as,
                 ctx.get_stream().get_rocblas(),
@@ -524,21 +247,16 @@ argument miopen_gemm::compute(context& ctx,
                 k * n,
                 to_pointer(args[0]),
                 lda,
-                0,
+                m * k,
                 &beta_r,
-                to_pointer(args[2]),
+                (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
                 ldc,
                 m * n,
-                batch_num);
-        });
-    }
-    // (batch) matrix multiplication
-    else
-    {
-        batch_matmul(ctx, output_shape, args);
-    }
-
-    return args[2];
+                num_matrices);
+        }
+    });
+
+    return (is_3inputs ? args[3] : args[2]);
 }
 } // namespace gpu
...
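A note on the GEMM calls in the gemm.cpp diff above: rocBLAS, like BLAS, works in column-major order, while MIGraphX buffers are row-major. Passing B before A and swapping m with n computes C^T = B^T * A^T in column-major terms, which is exactly C = A * B written out row-major. A tiny self-contained check of that identity (not MIGraphX code):

```cpp
#include <array>
#include <cassert>

// Verify (A*B)^T == B^T * A^T for a small example; this identity is why the
// calls above pass args[1] before args[0] and swap m with n.
int main()
{
    std::array<std::array<int, 2>, 2> a{{{1, 2}, {3, 4}}};
    std::array<std::array<int, 2>, 2> b{{{5, 6}, {7, 8}}};
    std::array<std::array<int, 2>, 2> ab{};   // A*B
    std::array<std::array<int, 2>, 2> btat{}; // B^T * A^T
    for(int i = 0; i < 2; i++)
        for(int j = 0; j < 2; j++)
            for(int k = 0; k < 2; k++)
            {
                ab[i][j] += a[i][k] * b[k][j];
                btat[i][j] += b[k][i] * a[j][k];
            }
    for(int i = 0; i < 2; i++)
        for(int j = 0; j < 2; j++)
            assert(ab[j][i] == btat[i][j]);
    return 0;
}
```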
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/batch_norm.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/concat.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_CONTIGUOUS_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/contiguous.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/convolution.hpp>
 #include <migraphx/gpu/miopen.hpp>
 namespace migraphx {
...
@@ -2,6 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
 #include <migraphx/shape.hpp>
+#include <migraphx/op/gather.hpp>
 #include <migraphx/gpu/miopen.hpp>
 namespace migraphx {
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/dot.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
 #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
 #define MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
-#include <migraphx/operators.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/argument.hpp>
+#include <migraphx/check_shapes.hpp>
 #include <utility>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
+struct context;
 argument allocate_gpu(const shape& s, bool host = false);
 argument to_gpu(const argument& arg, bool host = false);
...
@@ -4,7 +4,7 @@
 #include <migraphx/gpu/lowering.hpp>
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/logsoftmax.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
...
@@ -2,7 +2,9 @@
 #define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP
 #include <migraphx/manage_ptr.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/pooling.hpp>
+#include <migraphx/op/lrn.hpp>
 #include <miopen/miopen.h>
 #include <migraphx/config.hpp>
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/pad.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -2,7 +2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_POOLING_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/pooling.hpp>
 #include <migraphx/gpu/miopen.hpp>
 namespace migraphx {
...
@@ -2,7 2,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
 #include <migraphx/shape.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/softmax.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
 #include <migraphx/gpu/logsoftmax.hpp>
 #include <migraphx/gpu/device/logsoftmax.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/logsoftmax.hpp>
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <utility>
...
@@ -17,20 +17,27 @@
 #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/eliminate_concat.hpp>
+#include <migraphx/eliminate_identity.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/schedule_model.hpp>
+#include <migraphx/eliminate_pad.hpp>
 #include <migraphx/schedule.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
+
 std::vector<pass> target::get_passes(migraphx::context& gctx) const
 {
     auto& ctx = any_cast<context>(gctx);
     // clang-format off
     return
     {
+        dead_code_elimination{},
+        eliminate_identity{},
+        eliminate_pad{},
         dead_code_elimination{},
         fwd_conv_batchnorm_rewrite{},
         dead_code_elimination{},
@@ -53,13 +60,14 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         fuse_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
-        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
+        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
         memory_coloring{"hip::allocate"},
         dead_code_elimination{},
         eliminate_workspace{},
         eliminate_allocation{"hip::allocate"},
         check_context<context>{},
-        dead_code_elimination{}
+        dead_code_elimination{},
+        eliminate_identity{}
     };
     // clang-format on
 }
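The schedule pass above is now gated on MIGRAPHX_ENABLE_SCHEDULE_PASS through enabled(...). Assuming the usual behaviour of MIGraphX's environment-variable helpers (an assumption, not something this diff shows), setting the variable at run time turns the pass back on; a rough, hypothetical sketch of such a gate:

```cpp
#include <cstdlib>
#include <string>

// Hypothetical stand-in for an enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{}) check:
// treat the flag as on when the variable is set to anything but 0/false.
bool env_flag_enabled(const char* name)
{
    const char* v = std::getenv(name);
    if(v == nullptr)
        return false;
    std::string s(v);
    return !(s.empty() || s == "0" || s == "false");
}
// e.g. env_flag_enabled("MIGRAPHX_ENABLE_SCHEDULE_PASS")
```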
...