Commit a96585d6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'ref_gemm_par' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into layernorm_half2

parents c1ec929c 9b19b73f
......@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; }
argument argument::reshape(const shape& s) const
{
assert(s.element_space() <= this->get_shape().element_space());
return {s, this->m_data};
}
argument::data_t argument::data_t::share() const
{
......
......@@ -36,7 +36,6 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -67,7 +67,6 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -51,7 +51,6 @@ struct flatten
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -69,7 +69,6 @@ struct multibroadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -75,7 +75,6 @@ struct reshape
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -40,7 +40,6 @@ struct scalar
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -84,7 +84,6 @@ struct squeeze
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -72,8 +72,6 @@ struct step
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -56,7 +56,6 @@ struct transpose
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -81,7 +81,6 @@ struct unsqueeze
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -226,6 +226,8 @@ struct shape
const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
std::size_t element_space() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
......
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
......@@ -74,8 +74,10 @@ void migemm_impl(
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment