Commit dd26f1aa authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents 4e3d06ab 4a3e493c
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
......
......
@@ -56,7 +56,8 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
c = beta * c;
// This is a simple optimization to avoid
// computing A * B if alpha is 0.0
if(alpha != 0.0)
{
c = c + alpha * a * b;
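The hunk above guards the accumulation with an alpha check: since the operation is C = beta * C + alpha * A * B, a zero alpha reduces it to a pure scaling of C, so the expensive product loop can be skipped. A minimal standalone sketch of that idea (not the MIGraphX code path; gemm_sketch is a hypothetical row-major helper used only for illustration):

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper: computes C = beta * C + alpha * A * B
// for an m x k matrix A and a k x n matrix B, all row-major.
void gemm_sketch(std::vector<float>& c,
                 const std::vector<float>& a,
                 const std::vector<float>& b,
                 std::size_t m,
                 std::size_t n,
                 std::size_t k,
                 float alpha,
                 float beta)
{
    for(auto& x : c)
        x *= beta; // scale C first, as the patch does with c = beta * c

    if(alpha == 0.0f)
        return; // the A * B product contributes nothing, so skip the O(m*n*k) loop

    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(std::size_t p = 0; p < k; ++p)
                acc += a[i * k + p] * b[p * n + j];
            c[i * n + j] += alpha * acc;
        }
}

int main()
{
    std::vector<float> a{1, 2, 3, 4}, b{5, 6, 7, 8}, c{1, 1, 1, 1};
    gemm_sketch(c, a, b, 2, 2, 2, 0.0f, 2.0f); // alpha == 0: only C is scaled
    for(float x : c)
        std::cout << x << ' ';
    std::cout << '\n'; // prints "2 2 2 2"
}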
......
@@ -116,10 +117,11 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
-auto lens = cmat.get_shape().lens();
-std::size_t num_matrices = std::accumulate(
-    lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-if(num_matrices == 1)
+auto lens = amat.get_shape().lens();
+bool batch_mul =
+    std::accumulate(
+        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
+if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
......
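The second hunk switches the shape used for the dispatch decision from cmat to amat and treats the product of every dimension except the trailing two (the row/column dims) as the number of stacked matrices; when that product is 1, the single-GEMM path is taken. A small self-contained illustration of that accumulate call (num_matrices and the example shapes are illustrative only, not part of the library):

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::size_t num_matrices(const std::vector<std::size_t>& lens)
{
    // Multiply all dimensions except the last two, mirroring the
    // std::accumulate(lens.rbegin() + 2, lens.rend(), ...) call in the diff.
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}

int main()
{
    std::cout << num_matrices({64, 64}) << '\n';       // 1 -> single GEMM path
    std::cout << num_matrices({8, 64, 64}) << '\n';    // 8 -> batched path
    std::cout << num_matrices({2, 4, 64, 64}) << '\n'; // 8 -> batched path
}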